diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts
index 30382c6..70d55b7 100644
--- a/src/api/providers/openai-native.ts
+++ b/src/api/providers/openai-native.ts
@@ -23,48 +23,52 @@ export class OpenAiNativeHandler implements ApiHandler {
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		let systemPromptMessage: OpenAI.Chat.ChatCompletionMessageParam
-		let temperature = 0
 		switch (this.getModel().id) {
 			case "o1-preview":
-			case "o1-mini":
-				systemPromptMessage = { role: "user", content: systemPrompt }
-				temperature = 1
-				break
-			default:
-				systemPromptMessage = { role: "system", content: systemPrompt }
-				temperature = 0
-		}
-
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			systemPromptMessage,
-			...convertToOpenAiMessages(messages),
-		]
-
-		const stream = await this.client.chat.completions.create({
-			model: this.getModel().id,
-			// max_completion_tokens: this.getModel().info.maxTokens,
-			temperature,
-			messages: openAiMessages,
-			stream: true,
-			stream_options: { include_usage: true },
-		})
-
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-			if (delta?.content) {
+			case "o1-mini": {
+				// o1 doesn't support streaming, a non-1 temperature, or system prompts
+				const response = await this.client.chat.completions.create({
+					model: this.getModel().id,
+					messages: [{ role: "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+				})
 				yield {
 					type: "text",
-					text: delta.content,
+					text: response.choices[0]?.message.content || "",
 				}
-			}
-
-			// contains a null value except for the last chunk which contains the token usage statistics for the entire request
-			if (chunk.usage) {
 				yield {
 					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
+					inputTokens: response.usage?.prompt_tokens || 0,
+					outputTokens: response.usage?.completion_tokens || 0,
+				}
+				break
+			}
+			default: {
+				const stream = await this.client.chat.completions.create({
+					model: this.getModel().id,
+					// max_completion_tokens: this.getModel().info.maxTokens,
+					temperature: 0,
+					messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+					stream: true,
+					stream_options: { include_usage: true },
+				})
+
+				for await (const chunk of stream) {
+					const delta = chunk.choices[0]?.delta
+					if (delta?.content) {
+						yield {
+							type: "text",
+							text: delta.content,
+						}
+					}
+
+					// contains a null value except for the last chunk which contains the token usage statistics for the entire request
+					if (chunk.usage) {
+						yield {
+							type: "usage",
+							inputTokens: chunk.usage.prompt_tokens || 0,
+							outputTokens: chunk.usage.completion_tokens || 0,
+						}
+					}
+				}
+			}
 		}
 	}
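
Note on consuming the generator (not part of the diff): both branches surface results through the same ApiStream chunk shapes, { type: "text", text } and { type: "usage", inputTokens, outputTokens }, so callers do not need to know whether the model actually streamed. Below is a minimal consumer sketch under that assumption; the import path, prompt text, and sample message are illustrative placeholders, not code from the repository.

import { OpenAiNativeHandler } from "./src/api/providers/openai-native" // illustrative path, not the real import site

async function collectResponse(handler: OpenAiNativeHandler) {
	let fullText = ""
	let inputTokens = 0
	let outputTokens = 0

	// The same loop covers both branches: the o1 case yields one text chunk plus one
	// usage chunk, while the default case yields many text chunks and a final usage chunk.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello" },
	])) {
		if (chunk.type === "text") {
			fullText += chunk.text
		} else if (chunk.type === "usage") {
			inputTokens += chunk.inputTokens
			outputTokens += chunk.outputTokens
		}
	}

	return { fullText, inputTokens, outputTokens }
}

Folding the non-streaming o1 response into a single text chunk keeps a consumer loop like this unchanged, which is presumably why the o1 call is wrapped in the same generator rather than returned directly.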