diff --git a/src/api/providers/__tests__/openai-native.test.ts b/src/api/providers/__tests__/openai-native.test.ts
index f1da211..719a7fe 100644
--- a/src/api/providers/__tests__/openai-native.test.ts
+++ b/src/api/providers/__tests__/openai-native.test.ts
@@ -153,11 +153,35 @@ describe("OpenAiNativeHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith({
 				model: "o1",
 				messages: [
-					{ role: "developer", content: systemPrompt },
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
 					{ role: "user", content: "Hello!" },
 				],
 			})
 		})
+
+		it("should handle o3-mini model family correctly", async () => {
+			handler = new OpenAiNativeHandler({
+				...mockOptions,
+				apiModelId: "o3-mini",
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: "o3-mini",
+				messages: [
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
+					{ role: "user", content: "Hello!" },
+				],
+				stream: true,
+				stream_options: { include_usage: true },
+				reasoning_effort: "medium",
+			})
+		})
 	})
 
 	describe("streaming models", () => {
diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts
index 0b8908d..f1b5bce 100644
--- a/src/api/providers/openai-native.ts
+++ b/src/api/providers/openai-native.ts
@@ -24,57 +24,111 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const modelId = this.getModel().id
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini": {
-				// o1-preview and o1-mini don't support streaming, non-1 temp, or system prompt
-				// o1 doesnt support streaming or non-1 temp but does support a developer prompt
-				const response = await this.client.chat.completions.create({
-					model: modelId,
-					messages: [
-						{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
-						...convertToOpenAiMessages(messages),
-					],
-				})
+
+		if (modelId.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		if (modelId.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+	}
+
+	private async *handleO1FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		// o1 supports developer prompt with formatting
+		// o1-preview and o1-mini only support user messages
+		const isOriginalO1 = modelId === "o1"
+		const response = await this.client.chat.completions.create({
+			model: modelId,
+			messages: [
+				{
+					role: isOriginalO1 ? "developer" : "user",
+					content: isOriginalO1 ? `Formatting re-enabled\n${systemPrompt}` : systemPrompt,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+		})
+
+		yield* this.yieldResponseData(response)
+	}
+
+	private async *handleO3FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: "o3-mini",
+			messages: [
+				{
+					role: "developer",
+					content: `Formatting re-enabled\n${systemPrompt}`,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *handleDefaultModelMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: modelId,
+			temperature: 0,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *yieldResponseData(
+		response: OpenAI.Chat.Completions.ChatCompletion
+	): ApiStream {
+		yield {
+			type: "text",
+			text: response.choices[0]?.message.content || "",
+		}
+		yield {
+			type: "usage",
+			inputTokens: response.usage?.prompt_tokens || 0,
+			outputTokens: response.usage?.completion_tokens || 0,
+		}
+	}
+
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
+	): ApiStream {
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
 				yield {
 					type: "text",
-					text: response.choices[0]?.message.content || "",
+					text: delta.content,
 				}
+			}
+
+			if (chunk.usage) {
 				yield {
 					type: "usage",
-					inputTokens: response.usage?.prompt_tokens || 0,
-					outputTokens: response.usage?.completion_tokens || 0,
-				}
-				break
-			}
-			default: {
-				const stream = await this.client.chat.completions.create({
-					model: this.getModel().id,
-					// max_completion_tokens: this.getModel().info.maxTokens,
-					temperature: 0,
-					messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-					stream: true,
-					stream_options: { include_usage: true },
-				})
-
-				for await (const chunk of stream) {
-					const delta = chunk.choices[0]?.delta
-					if (delta?.content) {
-						yield {
-							type: "text",
-							text: delta.content,
-						}
-					}
-
-					// contains a null value except for the last chunk which contains the token usage statistics for the entire request
-					if (chunk.usage) {
-						yield {
-							type: "usage",
-							inputTokens: chunk.usage.prompt_tokens || 0,
-							outputTokens: chunk.usage.completion_tokens || 0,
-						}
-					}
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
 				}
 			}
 		}
@@ -94,22 +148,12 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 		const modelId = this.getModel().id
 		let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini":
-				// o1 doesn't support non-1 temp
-				requestOptions = {
-					model: modelId,
-					messages: [{ role: "user", content: prompt }],
-				}
-				break
-			default:
-				requestOptions = {
-					model: modelId,
-					messages: [{ role: "user", content: prompt }],
-					temperature: 0,
-				}
+		if (modelId.startsWith("o1")) {
+			requestOptions = this.getO1CompletionOptions(modelId, prompt)
+		} else if (modelId.startsWith("o3-mini")) {
+			requestOptions = this.getO3CompletionOptions(modelId, prompt)
+		} else {
+			requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
 		}
 
 		const response = await this.client.chat.completions.create(requestOptions)
@@ -121,4 +165,36 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 			throw error
 		}
 	}
+
+	private getO1CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+	}
+
+	private getO3CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: "o3-mini",
+			messages: [{ role: "user", content: prompt }],
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		}
+	}
+
+	private getDefaultCompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+			temperature: 0,
+		}
+	}
 }