Merge pull request #282 from RooVetGit/open_ai_streaming_toggle

Streaming checkbox for OpenAI-compatible providers
Matt Rubens
2025-01-05 22:46:50 -05:00
committed by GitHub
7 changed files with 260 additions and 45 deletions

View File

@@ -0,0 +1,5 @@
---
"roo-cline": patch
---
Checkbox to disable streaming for OpenAI-compatible providers
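
For illustration, a minimal sketch of how the new flag is consumed by the handler (the import path and placeholder values are assumptions, not part of this PR):

```typescript
import { OpenAiHandler } from "./src/api/providers/openai" // path assumed for illustration

// When openAiStreamingEnabled is undefined the handler falls back to `?? true`,
// so existing configurations keep streaming; unchecking the new box stores false
// and routes requests through the non-streaming completions path instead.
const handler = new OpenAiHandler({
	openAiApiKey: "sk-placeholder", // placeholder, not a real key
	openAiModelId: "gpt-4",
	openAiBaseUrl: "https://api.openai.com/v1",
	openAiStreamingEnabled: false,
})
```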

View File

@@ -18,6 +18,7 @@ A fork of Cline, an autonomous coding agent, with some additional experimental features
 - Support for Amazon Nova and Meta 3, 3.1, and 3.2 models via AWS Bedrock
 - Support for Glama
 - Support for listing models from OpenAI-compatible providers
+- Support for adding OpenAI-compatible models with or without streaming
 - Per-tool MCP auto-approval
 - Enable/disable individual MCP servers
 - Enable/disable the MCP feature overall

View File

@@ -0,0 +1,192 @@
import { OpenAiHandler } from '../openai'
import { ApiHandlerOptions, openAiModelInfoSaneDefaults } from '../../../shared/api'
import OpenAI, { AzureOpenAI } from 'openai'
import { Anthropic } from '@anthropic-ai/sdk'

// Mock dependencies
jest.mock('openai')

describe('OpenAiHandler', () => {
	const mockOptions: ApiHandlerOptions = {
		openAiApiKey: 'test-key',
		openAiModelId: 'gpt-4',
		openAiStreamingEnabled: true,
		openAiBaseUrl: 'https://api.openai.com/v1'
	}

	beforeEach(() => {
		jest.clearAllMocks()
	})

	test('constructor initializes with correct options', () => {
		const handler = new OpenAiHandler(mockOptions)
		expect(handler).toBeInstanceOf(OpenAiHandler)
		expect(OpenAI).toHaveBeenCalledWith({
			apiKey: mockOptions.openAiApiKey,
			baseURL: mockOptions.openAiBaseUrl
		})
	})

	test('constructor initializes Azure client when Azure URL is provided', () => {
		const azureOptions: ApiHandlerOptions = {
			...mockOptions,
			openAiBaseUrl: 'https://example.azure.com',
			azureApiVersion: '2023-05-15'
		}
		const handler = new OpenAiHandler(azureOptions)
		expect(handler).toBeInstanceOf(OpenAiHandler)
		expect(AzureOpenAI).toHaveBeenCalledWith({
			baseURL: azureOptions.openAiBaseUrl,
			apiKey: azureOptions.openAiApiKey,
			apiVersion: azureOptions.azureApiVersion
		})
	})

	test('getModel returns correct model info', () => {
		const handler = new OpenAiHandler(mockOptions)
		const result = handler.getModel()
		expect(result).toEqual({
			id: mockOptions.openAiModelId,
			info: openAiModelInfoSaneDefaults
		})
	})

	test('createMessage handles streaming correctly when enabled', async () => {
		const handler = new OpenAiHandler({
			...mockOptions,
			openAiStreamingEnabled: true,
			includeMaxTokens: true
		})
		const mockStream = {
			async *[Symbol.asyncIterator]() {
				yield {
					choices: [{
						delta: {
							content: 'test response'
						}
					}],
					usage: {
						prompt_tokens: 10,
						completion_tokens: 5
					}
				}
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockStream)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		const systemPrompt = 'test system prompt'
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: 'user', content: 'test message' }
		]
		const generator = handler.createMessage(systemPrompt, messages)
		const chunks = []
		for await (const chunk of generator) {
			chunks.push(chunk)
		}

		expect(chunks).toEqual([
			{
				type: 'text',
				text: 'test response'
			},
			{
				type: 'usage',
				inputTokens: 10,
				outputTokens: 5
			}
		])
		expect(mockCreate).toHaveBeenCalledWith({
			model: mockOptions.openAiModelId,
			messages: [
				{ role: 'system', content: systemPrompt },
				{ role: 'user', content: 'test message' }
			],
			temperature: 0,
			stream: true,
			stream_options: { include_usage: true },
			max_tokens: openAiModelInfoSaneDefaults.maxTokens
		})
	})

	test('createMessage handles non-streaming correctly when disabled', async () => {
		const handler = new OpenAiHandler({
			...mockOptions,
			openAiStreamingEnabled: false
		})
		const mockResponse = {
			choices: [{
				message: {
					content: 'test response'
				}
			}],
			usage: {
				prompt_tokens: 10,
				completion_tokens: 5
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockResponse)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		const systemPrompt = 'test system prompt'
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: 'user', content: 'test message' }
		]
		const generator = handler.createMessage(systemPrompt, messages)
		const chunks = []
		for await (const chunk of generator) {
			chunks.push(chunk)
		}

		expect(chunks).toEqual([
			{
				type: 'text',
				text: 'test response'
			},
			{
				type: 'usage',
				inputTokens: 10,
				outputTokens: 5
			}
		])
		expect(mockCreate).toHaveBeenCalledWith({
			model: mockOptions.openAiModelId,
			messages: [
				{ role: 'user', content: systemPrompt },
				{ role: 'user', content: 'test message' }
			]
		})
	})

	test('createMessage handles API errors', async () => {
		const handler = new OpenAiHandler(mockOptions)
		const mockStream = {
			async *[Symbol.asyncIterator]() {
				throw new Error('API Error')
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockStream)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		const generator = handler.createMessage('test', [])
		await expect(generator.next()).rejects.toThrow('API Error')
	})
})
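
Both createMessage tests assert the same two-chunk sequence, which is the point of the change: callers see an identical stream shape whether or not streaming is enabled. A hedged consumer sketch relying on that contract (the chunk type and helper below are illustrative, not part of the PR):

```typescript
// Chunk shapes mirror what the tests above assert.
type ApiStreamChunk =
	| { type: "text"; text: string }
	| { type: "usage"; inputTokens: number; outputTokens: number }

// Accumulates a complete response from either code path; because both
// branches yield the same shapes, no mode-specific handling is needed.
async function collectResponse(stream: AsyncIterable<ApiStreamChunk>) {
	let text = ""
	let inputTokens = 0
	let outputTokens = 0
	for await (const chunk of stream) {
		if (chunk.type === "text") {
			text += chunk.text
		} else {
			inputTokens += chunk.inputTokens
			outputTokens += chunk.outputTokens
		}
	}
	return { text, inputTokens, outputTokens }
}
```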

View File

@@ -32,42 +32,64 @@ export class OpenAiHandler implements ApiHandler {
 		}
 	}

+	// Include stream_options for OpenAI Compatible providers if the checkbox is checked
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
 		const modelInfo = this.getModel().info
-		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
-			model: this.options.openAiModelId ?? "",
-			messages: openAiMessages,
-			temperature: 0,
-			stream: true,
-		}
-		if (this.options.includeMaxTokens) {
-			requestOptions.max_tokens = modelInfo.maxTokens
-		}
-		if (this.options.includeStreamOptions ?? true) {
-			requestOptions.stream_options = { include_usage: true }
-		}
-		const stream = await this.client.chat.completions.create(requestOptions)
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
-			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
-			}
-		}
+		const modelId = this.options.openAiModelId ?? ""
+
+		if (this.options.openAiStreamingEnabled ?? true) {
+			const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
+				role: "system",
+				content: systemPrompt
+			}
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+				model: modelId,
+				temperature: 0,
+				messages: [systemMessage, ...convertToOpenAiMessages(messages)],
+				stream: true as const,
+				stream_options: { include_usage: true },
+			}
+			if (this.options.includeMaxTokens) {
+				requestOptions.max_tokens = modelInfo.maxTokens
+			}
+			const stream = await this.client.chat.completions.create(requestOptions)
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+				if (delta?.content) {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
+				}
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
+				}
+			}
+		} else {
+			// o1 for instance doesn't support streaming, a non-1 temperature, or a system prompt
+			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
+				role: "user",
+				content: systemPrompt
+			}
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: modelId,
+				messages: [systemMessage, ...convertToOpenAiMessages(messages)],
+			}
+			const response = await this.client.chat.completions.create(requestOptions)
+			yield {
+				type: "text",
+				text: response.choices[0]?.message.content || "",
+			}
+			yield {
+				type: "usage",
+				inputTokens: response.usage?.prompt_tokens || 0,
+				outputTokens: response.usage?.completion_tokens || 0,
+			}
+		}
 	}
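
To summarize the branch difference: a sketch of the request payload each path sends (field values mirror the code above; comments mark the assumptions):

```typescript
// Streaming branch (the default, openAiStreamingEnabled !== false):
const streamingRequest = {
	model: "gpt-4", // example model id
	temperature: 0,
	messages: [{ role: "system", content: "<system prompt>" } /* , ...converted messages */],
	stream: true,
	stream_options: { include_usage: true }, // usage arrives on the final chunk
	// max_tokens is added only when includeMaxTokens is set
}

// Non-streaming branch: no temperature, and the system prompt is sent as a
// user message, matching models like o1 that reject both.
const nonStreamingRequest = {
	model: "o1", // example of a model that needs this branch
	messages: [{ role: "user", content: "<system prompt>" } /* , ...converted messages */],
}
```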

View File

@@ -66,7 +66,7 @@ type GlobalStateKey =
 	| "lmStudioBaseUrl"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
-	| "includeStreamOptions"
+	| "openAiStreamingEnabled"
 	| "openRouterModelId"
 	| "openRouterModelInfo"
 	| "openRouterUseMiddleOutTransform"
@@ -447,7 +447,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			geminiApiKey,
 			openAiNativeApiKey,
 			azureApiVersion,
-			includeStreamOptions,
+			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
@@ -478,7 +478,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
 		await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey)
 		await this.updateGlobalState("azureApiVersion", azureApiVersion)
-		await this.updateGlobalState("includeStreamOptions", includeStreamOptions)
+		await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)
 		await this.updateGlobalState("openRouterModelId", openRouterModelId)
 		await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
 		await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
@@ -1295,7 +1295,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			azureApiVersion,
-			includeStreamOptions,
+			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
@@ -1345,7 +1345,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
-			this.getGlobalState("includeStreamOptions") as Promise<boolean | undefined>,
+			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
 			this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
 			this.getGlobalState("openRouterUseMiddleOutTransform") as Promise<boolean | undefined>,
@@ -1412,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			azureApiVersion,
-			includeStreamOptions,
+			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
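
A hedged sketch of the persistence round trip for the renamed key (the method calls are the ones shown in the diff; the surrounding destructuring is elided):

```typescript
// On settings save: persist the checkbox value from the webview message.
await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)

// On state load: a value that was never written comes back undefined, which
// the handler's `?? true` treats as streaming enabled, so fresh installs and
// upgrades default to the old streaming behavior.
const streamingEnabled = (await this.getGlobalState("openAiStreamingEnabled")) as
	| boolean
	| undefined
```

One nuance worth noting: because the key changed, a previously stored "includeStreamOptions" value is no longer read, so users who had unchecked the old box fall back to the streaming default until they toggle the new checkbox.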

View File

@@ -41,7 +41,7 @@ export interface ApiHandlerOptions {
 	openAiNativeApiKey?: string
 	azureApiVersion?: string
 	openRouterUseMiddleOutTransform?: boolean
-	includeStreamOptions?: boolean
+	openAiStreamingEnabled?: boolean
 	setAzureApiVersion?: boolean
 	deepSeekBaseUrl?: string
 	deepSeekApiKey?: string

View File

@@ -477,21 +477,16 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 				<OpenAiModelPicker />
 				<div style={{ display: 'flex', alignItems: 'center' }}>
 					<VSCodeCheckbox
-						checked={apiConfiguration?.includeStreamOptions ?? true}
+						checked={apiConfiguration?.openAiStreamingEnabled ?? true}
 						onChange={(e: any) => {
 							const isChecked = e.target.checked
 							setApiConfiguration({
 								...apiConfiguration,
-								includeStreamOptions: isChecked
+								openAiStreamingEnabled: isChecked
 							})
 						}}>
-						Include stream options
+						Enable streaming
 					</VSCodeCheckbox>
-					<span
-						className="codicon codicon-info"
-						title="Stream options are for { include_usage: true }. Some providers may not support this option."
-						style={{ marginLeft: '5px', cursor: 'help' }}
-					></span>
 				</div>
 				<VSCodeCheckbox
 					checked={azureApiVersionSelected}