diff --git a/.changeset/light-shoes-rescue.md b/.changeset/light-shoes-rescue.md
new file mode 100644
index 0000000..0da95a2
--- /dev/null
+++ b/.changeset/light-shoes-rescue.md
@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Checkbox to disable streaming for OpenAI-compatible providers
diff --git a/README.md b/README.md
index 64b0117..88cf173 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ A fork of Cline, an autonomous coding agent, with some additional experimental f
 - Support for Amazon Nova and Meta 3, 3.1, and 3.2 models via AWS Bedrock
 - Support for Glama
 - Support for listing models from OpenAI-compatible providers
+- Support for adding OpenAI-compatible models with or without streaming
 - Per-tool MCP auto-approval
 - Enable/disable individual MCP servers
 - Enable/disable the MCP feature overall
diff --git a/src/api/providers/__tests__/openai.test.ts b/src/api/providers/__tests__/openai.test.ts
new file mode 100644
index 0000000..0a88068
--- /dev/null
+++ b/src/api/providers/__tests__/openai.test.ts
@@ -0,0 +1,192 @@
+import { OpenAiHandler } from '../openai'
+import { ApiHandlerOptions, openAiModelInfoSaneDefaults } from '../../../shared/api'
+import OpenAI, { AzureOpenAI } from 'openai'
+import { Anthropic } from '@anthropic-ai/sdk'
+
+// Mock dependencies
+jest.mock('openai')
+
+describe('OpenAiHandler', () => {
+	const mockOptions: ApiHandlerOptions = {
+		openAiApiKey: 'test-key',
+		openAiModelId: 'gpt-4',
+		openAiStreamingEnabled: true,
+		openAiBaseUrl: 'https://api.openai.com/v1'
+	}
+
+	beforeEach(() => {
+		jest.clearAllMocks()
+	})
+
+	test('constructor initializes with correct options', () => {
+		const handler = new OpenAiHandler(mockOptions)
+		expect(handler).toBeInstanceOf(OpenAiHandler)
+		expect(OpenAI).toHaveBeenCalledWith({
+			apiKey: mockOptions.openAiApiKey,
+			baseURL: mockOptions.openAiBaseUrl
+		})
+	})
+
+	test('constructor initializes Azure client when Azure URL is provided', () => {
+		const azureOptions: ApiHandlerOptions = {
+			...mockOptions,
+			openAiBaseUrl: 'https://example.azure.com',
+			azureApiVersion: '2023-05-15'
+		}
+		const handler = new OpenAiHandler(azureOptions)
+		expect(handler).toBeInstanceOf(OpenAiHandler)
+		expect(AzureOpenAI).toHaveBeenCalledWith({
+			baseURL: azureOptions.openAiBaseUrl,
+			apiKey: azureOptions.openAiApiKey,
+			apiVersion: azureOptions.azureApiVersion
+		})
+	})
+
+	test('getModel returns correct model info', () => {
+		const handler = new OpenAiHandler(mockOptions)
+		const result = handler.getModel()
+
+		expect(result).toEqual({
+			id: mockOptions.openAiModelId,
+			info: openAiModelInfoSaneDefaults
+		})
+	})
+
+	test('createMessage handles streaming correctly when enabled', async () => {
+		const handler = new OpenAiHandler({
+			...mockOptions,
+			openAiStreamingEnabled: true,
+			includeMaxTokens: true
+		})
+
+		const mockStream = {
+			async *[Symbol.asyncIterator]() {
+				yield {
+					choices: [{
+						delta: {
+							content: 'test response'
+						}
+					}],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5
+					}
+				}
+			}
+		}
+
+		const mockCreate = jest.fn().mockResolvedValue(mockStream)
+		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+			completions: { create: mockCreate }
+		} as any
+
+		const systemPrompt = 'test system prompt'
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: 'user', content: 'test message' }
+		]
+
+		const generator = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+
+		for await (const chunk of generator) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{
+				type: 'text',
+				text: 'test response'
+			},
+			{
+				type: 'usage',
+				inputTokens: 10,
+				outputTokens: 5
+			}
+		])
+
+		expect(mockCreate).toHaveBeenCalledWith({
+			model: mockOptions.openAiModelId,
+			messages: [
+				{ role: 'system', content: systemPrompt },
+				{ role: 'user', content: 'test message' }
+			],
+			temperature: 0,
+			stream: true,
+			stream_options: { include_usage: true },
+			max_tokens: openAiModelInfoSaneDefaults.maxTokens
+		})
+	})
+
+	test('createMessage handles non-streaming correctly when disabled', async () => {
+		const handler = new OpenAiHandler({
+			...mockOptions,
+			openAiStreamingEnabled: false
+		})
+
+		const mockResponse = {
+			choices: [{
+				message: {
+					content: 'test response'
+				}
+			}],
+			usage: {
+				prompt_tokens: 10,
+				completion_tokens: 5
+			}
+		}
+
+		const mockCreate = jest.fn().mockResolvedValue(mockResponse)
+		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+			completions: { create: mockCreate }
+		} as any
+
+		const systemPrompt = 'test system prompt'
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: 'user', content: 'test message' }
+		]
+
+		const generator = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+
+		for await (const chunk of generator) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{
+				type: 'text',
+				text: 'test response'
+			},
+			{
+				type: 'usage',
+				inputTokens: 10,
+				outputTokens: 5
+			}
+		])
+
+		expect(mockCreate).toHaveBeenCalledWith({
+			model: mockOptions.openAiModelId,
+			messages: [
+				{ role: 'user', content: systemPrompt },
+				{ role: 'user', content: 'test message' }
+			]
+		})
+	})
+
+	test('createMessage handles API errors', async () => {
+		const handler = new OpenAiHandler(mockOptions)
+		const mockStream = {
+			async *[Symbol.asyncIterator]() {
+				throw new Error('API Error')
+			}
+		}
+
+		const mockCreate = jest.fn().mockResolvedValue(mockStream)
+		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+			completions: { create: mockCreate }
+		} as any
+
+		const generator = handler.createMessage('test', [])
+		await expect(generator.next()).rejects.toThrow('API Error')
+	})
+})
\ No newline at end of file
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 071df8d..3ec2192 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -32,42 +32,64 @@ export class OpenAiHandler implements ApiHandler {
 		}
 	}
 
-	// Include stream_options for OpenAI Compatible providers if the checkbox is checked
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
 		const modelInfo = this.getModel().info
-		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
-			model: this.options.openAiModelId ?? "",
-			messages: openAiMessages,
-			temperature: 0,
-			stream: true,
-		}
-		if (this.options.includeMaxTokens) {
-			requestOptions.max_tokens = modelInfo.maxTokens
-		}
+		const modelId = this.options.openAiModelId ?? ""
 
-		if (this.options.includeStreamOptions ?? true) {
-			requestOptions.stream_options = { include_usage: true }
-		}
+		if (this.options.openAiStreamingEnabled ?? true) {
+			const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
+				role: "system",
+				content: systemPrompt
+			}
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+				model: modelId,
+				temperature: 0,
+				messages: [systemMessage, ...convertToOpenAiMessages(messages)],
+				stream: true as const,
+				stream_options: { include_usage: true },
+			}
+			if (this.options.includeMaxTokens) {
+				requestOptions.max_tokens = modelInfo.maxTokens
+			}
 
-		const stream = await this.client.chat.completions.create(requestOptions)
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
+			const stream = await this.client.chat.completions.create(requestOptions)
+
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+				if (delta?.content) {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
+				}
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
 				}
 			}
-			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+		} else {
+			// o1, for instance, doesn't support streaming, a non-1 temperature, or a system prompt
+			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
+				role: "user",
+				content: systemPrompt
+			}
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: modelId,
+				messages: [systemMessage, ...convertToOpenAiMessages(messages)],
+			}
+			const response = await this.client.chat.completions.create(requestOptions)
+
+			yield {
+				type: "text",
+				text: response.choices[0]?.message.content || "",
+			}
+			yield {
+				type: "usage",
+				inputTokens: response.usage?.prompt_tokens || 0,
+				outputTokens: response.usage?.completion_tokens || 0,
 			}
 		}
 	}
diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts
index a49caee..025cb88 100644
--- a/src/core/webview/ClineProvider.ts
+++ b/src/core/webview/ClineProvider.ts
@@ -66,7 +66,7 @@ type GlobalStateKey =
 	| "lmStudioBaseUrl"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
-	| "includeStreamOptions"
+	| "openAiStreamingEnabled"
 	| "openRouterModelId"
 	| "openRouterModelInfo"
 	| "openRouterUseMiddleOutTransform"
@@ -447,7 +447,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						geminiApiKey,
 						openAiNativeApiKey,
 						azureApiVersion,
-						includeStreamOptions,
+						openAiStreamingEnabled,
 						openRouterModelId,
 						openRouterModelInfo,
 						openRouterUseMiddleOutTransform,
@@ -478,7 +478,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 					await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
 					await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey)
 					await this.updateGlobalState("azureApiVersion", azureApiVersion)
-					await this.updateGlobalState("includeStreamOptions", includeStreamOptions)
+					await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)
 					await this.updateGlobalState("openRouterModelId", openRouterModelId)
 					await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
 					await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
@@ -1295,7 +1295,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			azureApiVersion,
-			includeStreamOptions,
+			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
@@ -1345,7 +1345,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
-			this.getGlobalState("includeStreamOptions") as Promise<boolean | undefined>,
+			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
 			this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
 			this.getGlobalState("openRouterUseMiddleOutTransform") as Promise<boolean | undefined>,
@@ -1412,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			azureApiVersion,
-			includeStreamOptions,
+			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 2863893..b30c5ed 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -41,7 +41,7 @@ export interface ApiHandlerOptions {
 	openAiNativeApiKey?: string
 	azureApiVersion?: string
 	openRouterUseMiddleOutTransform?: boolean
-	includeStreamOptions?: boolean
+	openAiStreamingEnabled?: boolean
 	setAzureApiVersion?: boolean
 	deepSeekBaseUrl?: string
 	deepSeekApiKey?: string
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index ebeab8d..f38a0a3 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -477,21 +477,16 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 						<VSCodeCheckbox
-							checked={apiConfiguration?.includeStreamOptions ?? true}
+							checked={apiConfiguration?.openAiStreamingEnabled ?? true}
 							onChange={(e: any) => {
 								const isChecked = e.target.checked
 								setApiConfiguration({
 									...apiConfiguration,
-									includeStreamOptions: isChecked
+									openAiStreamingEnabled: isChecked
 								})
 							}}>
-							Include stream options
+							Enable streaming
 						</VSCodeCheckbox>
-
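For reference, a minimal sketch of how the renamed `openAiStreamingEnabled` option is exercised once this patch is applied. The import paths, API key, base URL, and model id below are illustrative placeholders and not part of the change; the behavior shown (one non-streaming request when the checkbox is unchecked, SSE chunks otherwise) follows from the `createMessage` branches above.

```typescript
import { OpenAiHandler } from "./src/api/providers/openai"
import { ApiHandlerOptions } from "./src/shared/api"

// Illustrative configuration; unchecking the "Enable streaming" checkbox in the
// settings UI corresponds to openAiStreamingEnabled being false here.
const options: ApiHandlerOptions = {
	openAiApiKey: "sk-placeholder",
	openAiBaseUrl: "https://api.openai.com/v1",
	openAiModelId: "gpt-4",
	openAiStreamingEnabled: false, // single non-streaming completion instead of streamed chunks
}

async function main() {
	const handler = new OpenAiHandler(options)

	// Both branches yield the same chunk shapes:
	// { type: "text", text } followed by { type: "usage", inputTokens, outputTokens }.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello" },
	])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
		}
	}
}

main()
```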