26 bool LoadModel(
TinyLLMAsset* asset, int32_t maxSeqLen = 0);
28 bool IsModelLoaded()
const;
/// Forward-pass entry point (per its name). Declaration only.
/// @param token Token id to process.
/// @param pos   Position argument. NOTE(review): presumably the index of
///              @p token within the sequence — confirm valid range.
/// @return Raw float pointer. NOTE(review): presumably a non-owning view of the
///         logits (cf. mLogits) that stays owned by this object — confirm
///         before storing or freeing it.
float* Forward(int32_t token, int32_t pos);
/// Sampling entry point (per its name). Declaration only.
/// @param temperature Defaults to 1.0f.
/// @param topP        Defaults to 0.9f. NOTE(review): presumably a nucleus
///                    (top-p) threshold — confirm.
/// @return int32_t. NOTE(review): presumably the sampled token id — confirm
///         which logits it reads (cf. SampleInternal, mLogits).
int32_t Sample(float temperature = 1.0f, float topP = 0.9f);
/// Text-generation entry point returning a std::string. Declaration only.
/// @param prompt      Input text, taken by const reference.
/// @param maxTokens   Token budget (no default). NOTE(review): whether this
///                    counts prompt tokens is not visible here — confirm.
/// @param temperature Defaults to 1.0f.
/// @param topP        Defaults to 0.9f.
/// @return std::string. NOTE(review): presumably the generated text — confirm
///         whether the prompt is included.
std::string Generate(const std::string& prompt, int32_t maxTokens,
                     float temperature = 1.0f, float topP = 0.9f);
/// Converts text to a vector of int32_t values (token ids, per the name).
/// Declaration only.
/// @param text   Input text, taken by const reference.
/// @param addBos Defaults to true. NOTE(review): presumably prepends a
///               beginning-of-sequence token — confirm.
/// @return Vector of int32_t, returned by value.
std::vector<int32_t> Encode(const std::string& text, bool addBos = true);
/// Converts a token to text (per its name). Declaration only.
/// @param prevToken Preceding token. NOTE(review): why the previous token is
///                  needed is not visible here — confirm in the definition.
/// @param token     Token to decode.
/// @return std::string, returned by value.
std::string Decode(int32_t prevToken, int32_t token);
/// Starts a generation that is advanced by later calls (per the names of
/// BeginGenerate / ContinueGenerate / IsGenerating). Declaration only.
/// Takes the same parameter list and defaults as Generate().
/// @param prompt      Input text, taken by const reference.
/// @param maxTokens   Token budget (no default).
/// @param temperature Defaults to 1.0f.
/// @param topP        Defaults to 0.9f.
/// @return bool. NOTE(review): presumably true when generation was started —
///         confirm failure conditions.
bool BeginGenerate(const std::string& prompt, int32_t maxTokens,
                   float temperature = 1.0f, float topP = 0.9f);
/// Advances a generation started with BeginGenerate() (per its name).
/// Declaration only.
/// @return std::string. NOTE(review): presumably the text produced by this
///         step, possibly empty — confirm how completion is signalled
///         (cf. IsGenerating()).
std::string ContinueGenerate();
48 bool IsGenerating()
const;
52 float GetLastTokPerSec()
const;
53 int32_t GetPosition()
const;
54 int32_t GetMaxSeqLen()
const;
/// Static function returning an int32_t; callable without an instance.
/// NOTE(review): presumably the sequence length used when LoadModel() is given
/// maxSeqLen == 0 — confirm against the definition.
static int32_t GetDefaultMaxSeqLen();
/// Allocates run-state storage (per its name). Declaration only.
/// @param maxSeqLen Sequence length to size for.
/// NOTE(review): presumably fills mRunStateBuffer / mRunStateSize — confirm,
/// including how a previously allocated buffer is released.
void AllocateRunState(int32_t maxSeqLen);
/// Sets up the run-state pointers (per its name). Declaration only.
/// NOTE(review): presumably points the float* members (mXb2, mHb2, mKeyCache,
/// mValueCache, mAtt, mLogits, ...) into mRunStateBuffer — confirm.
void SetupRunStatePointers();
/// Sampling helper that takes the logits explicitly. Declaration only.
/// Unlike Sample(), no parameter has a default.
/// @param logits      Non-const float pointer. NOTE(review): the element count
///                    and whether the array is modified in place are not
///                    visible here — confirm.
/// @param temperature Sampling temperature.
/// @param topP        NOTE(review): presumably a nucleus (top-p) threshold —
///                    confirm.
/// @return int32_t. NOTE(review): presumably the chosen token id — confirm.
int32_t SampleInternal(float* logits, float temperature, float topP);
// Run-state storage; starts out null with size 0.
// NOTE(review): presumably allocated by AllocateRunState(). Ownership and the
// unit of mRunStateSize (bytes vs. float count) are not visible here — confirm.
float* mRunStateBuffer = nullptr;
size_t mRunStateSize = 0;
// Raw float pointers, all null by default.
// NOTE(review): presumably non-owning views into mRunStateBuffer assigned by
// SetupRunStatePointers(). Per the names: scratch activations (mXb2, mHb2),
// key/value caches, attention scores and output logits — confirm sizes and
// layout against the definitions.
float* mXb2 = nullptr;
float* mHb2 = nullptr;
float* mKeyCache = nullptr;
float* mValueCache = nullptr;
float* mAtt = nullptr;
float* mLogits = nullptr;
94 std::vector<ProbIndex> mProbIndex;
// 64-bit random-number-generator state (per its name); initialised to 0.
// NOTE(review): where it is seeded is not visible here — confirm a 0 state is
// valid for the generator in use.
uint64_t mRngState = 0;
// Maximum sequence length (per its name); 0 by default.
// NOTE(review): presumably set by LoadModel()/AllocateRunState() and returned
// by GetMaxSeqLen() — confirm.
int32_t mMaxSeqLen = 0;
// Generation state; defaults describe an idle object.
// NOTE(review): presumably written by BeginGenerate() and advanced by
// ContinueGenerate() — confirm against the definitions.
bool mIsGenerating = false;          // no generation in progress by default
std::vector<int32_t> mPromptTokens;  // empty by default; presumably the encoded prompt — confirm
int32_t mPromptIdx = 0;              // presumably an index into mPromptTokens — confirm
int32_t mLastToken = 0;
int32_t mGeneratedCount = 0;
int32_t mMaxGenTokens = 0;
float mTemperature = 1.0f;           // same default as the temperature parameters above
// Throughput bookkeeping; both start at zero.
// NOTE(review): mLastTokPerSec is presumably what GetLastTokPerSec() returns.
// The clock and unit behind mGenStartTime are not visible here — confirm.
float mLastTokPerSec = 0.0f;
int64_t mGenStartTime = 0;