Streaming checkbox for OpenAI-compatible providers

This commit is contained in:
Matt Rubens
2025-01-05 20:14:58 -05:00
parent 376ffa3f2a
commit 2cdfff02c0
4 changed files with 62 additions and 45 deletions

View File

@@ -32,28 +32,28 @@ export class OpenAiHandler implements ApiHandler {
} }
} }
// Include stream_options for OpenAI Compatible providers if the checkbox is checked
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const modelInfo = this.getModel().info const modelInfo = this.getModel().info
const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { const modelId = this.options.openAiModelId ?? ""
model: this.options.openAiModelId ?? "",
messages: openAiMessages, if (this.options.openAiStreamingEnabled ?? true) {
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
role: "system",
content: systemPrompt
}
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
temperature: 0, temperature: 0,
stream: true, messages: [systemMessage, ...convertToOpenAiMessages(messages)],
stream: true as const,
stream_options: { include_usage: true },
} }
if (this.options.includeMaxTokens) { if (this.options.includeMaxTokens) {
requestOptions.max_tokens = modelInfo.maxTokens requestOptions.max_tokens = modelInfo.maxTokens
} }
if (this.options.includeStreamOptions ?? true) {
requestOptions.stream_options = { include_usage: true }
}
const stream = await this.client.chat.completions.create(requestOptions) const stream = await this.client.chat.completions.create(requestOptions)
for await (const chunk of stream) { for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta const delta = chunk.choices[0]?.delta
if (delta?.content) { if (delta?.content) {
@@ -70,6 +70,28 @@ export class OpenAiHandler implements ApiHandler {
} }
} }
} }
} else {
// o1 for instance doesnt support streaming, non-1 temp, or system prompt
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
role: "user",
content: systemPrompt
}
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [systemMessage, ...convertToOpenAiMessages(messages)],
}
const response = await this.client.chat.completions.create(requestOptions)
yield {
type: "text",
text: response.choices[0]?.message.content || "",
}
yield {
type: "usage",
inputTokens: response.usage?.prompt_tokens || 0,
outputTokens: response.usage?.completion_tokens || 0,
}
}
} }
getModel(): { id: string; info: ModelInfo } { getModel(): { id: string; info: ModelInfo } {

View File

@@ -66,7 +66,7 @@ type GlobalStateKey =
| "lmStudioBaseUrl" | "lmStudioBaseUrl"
| "anthropicBaseUrl" | "anthropicBaseUrl"
| "azureApiVersion" | "azureApiVersion"
| "includeStreamOptions" | "openAiStreamingEnabled"
| "openRouterModelId" | "openRouterModelId"
| "openRouterModelInfo" | "openRouterModelInfo"
| "openRouterUseMiddleOutTransform" | "openRouterUseMiddleOutTransform"
@@ -447,7 +447,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
geminiApiKey, geminiApiKey,
openAiNativeApiKey, openAiNativeApiKey,
azureApiVersion, azureApiVersion,
includeStreamOptions, openAiStreamingEnabled,
openRouterModelId, openRouterModelId,
openRouterModelInfo, openRouterModelInfo,
openRouterUseMiddleOutTransform, openRouterUseMiddleOutTransform,
@@ -478,7 +478,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey) await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey) await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey)
await this.updateGlobalState("azureApiVersion", azureApiVersion) await this.updateGlobalState("azureApiVersion", azureApiVersion)
await this.updateGlobalState("includeStreamOptions", includeStreamOptions) await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)
await this.updateGlobalState("openRouterModelId", openRouterModelId) await this.updateGlobalState("openRouterModelId", openRouterModelId)
await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo) await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform) await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
@@ -1295,7 +1295,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiNativeApiKey, openAiNativeApiKey,
deepSeekApiKey, deepSeekApiKey,
azureApiVersion, azureApiVersion,
includeStreamOptions, openAiStreamingEnabled,
openRouterModelId, openRouterModelId,
openRouterModelInfo, openRouterModelInfo,
openRouterUseMiddleOutTransform, openRouterUseMiddleOutTransform,
@@ -1345,7 +1345,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>, this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getSecret("deepSeekApiKey") as Promise<string | undefined>, this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
this.getGlobalState("azureApiVersion") as Promise<string | undefined>, this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
this.getGlobalState("includeStreamOptions") as Promise<boolean | undefined>, this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
this.getGlobalState("openRouterModelId") as Promise<string | undefined>, this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>, this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
this.getGlobalState("openRouterUseMiddleOutTransform") as Promise<boolean | undefined>, this.getGlobalState("openRouterUseMiddleOutTransform") as Promise<boolean | undefined>,
@@ -1412,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiNativeApiKey, openAiNativeApiKey,
deepSeekApiKey, deepSeekApiKey,
azureApiVersion, azureApiVersion,
includeStreamOptions, openAiStreamingEnabled,
openRouterModelId, openRouterModelId,
openRouterModelInfo, openRouterModelInfo,
openRouterUseMiddleOutTransform, openRouterUseMiddleOutTransform,

View File

@@ -41,7 +41,7 @@ export interface ApiHandlerOptions {
openAiNativeApiKey?: string openAiNativeApiKey?: string
azureApiVersion?: string azureApiVersion?: string
openRouterUseMiddleOutTransform?: boolean openRouterUseMiddleOutTransform?: boolean
includeStreamOptions?: boolean openAiStreamingEnabled?: boolean
setAzureApiVersion?: boolean setAzureApiVersion?: boolean
deepSeekBaseUrl?: string deepSeekBaseUrl?: string
deepSeekApiKey?: string deepSeekApiKey?: string

View File

@@ -477,21 +477,16 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
<OpenAiModelPicker /> <OpenAiModelPicker />
<div style={{ display: 'flex', alignItems: 'center' }}> <div style={{ display: 'flex', alignItems: 'center' }}>
<VSCodeCheckbox <VSCodeCheckbox
checked={apiConfiguration?.includeStreamOptions ?? true} checked={apiConfiguration?.openAiStreamingEnabled ?? true}
onChange={(e: any) => { onChange={(e: any) => {
const isChecked = e.target.checked const isChecked = e.target.checked
setApiConfiguration({ setApiConfiguration({
...apiConfiguration, ...apiConfiguration,
includeStreamOptions: isChecked openAiStreamingEnabled: isChecked
}) })
}}> }}>
Include stream options Enable streaming
</VSCodeCheckbox> </VSCodeCheckbox>
<span
className="codicon codicon-info"
title="Stream options are for { include_usage: true }. Some providers may not support this option."
style={{ marginLeft: '5px', cursor: 'help' }}
></span>
</div> </div>
<VSCodeCheckbox <VSCodeCheckbox
checked={azureApiVersionSelected} checked={azureApiVersionSelected}