From c209198b233123274387b6ce72f0ba0ea78fc4ab Mon Sep 17 00:00:00 2001
From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com>
Date: Tue, 3 Sep 2024 17:08:29 -0400
Subject: [PATCH] Add openai compatible provider

---
 src/api/index.ts                             |   7 +-
 src/api/openai.ts                            |  74 +++++++++
 src/api/openrouter.ts                        |  54 +------
 src/providers/ClaudeDevProvider.ts           |  24 ++-
 src/shared/api.ts                            |  18 ++-
 src/utils/openai-format.ts                   |  57 +++++++
 webview-ui/src/components/ApiOptions.tsx     |  62 +++++++-
 webview-ui/src/components/ChatView.tsx       |   3 -
 webview-ui/src/components/HistoryPreview.tsx |  10 +-
 webview-ui/src/components/HistoryView.tsx    | 143 ++++++++++--------
 webview-ui/src/components/SettingsView.tsx   |   7 +-
 webview-ui/src/components/TaskHeader.tsx     |  92 ++++++-----
 .../src/context/ExtensionStateContext.tsx    |  10 +-
 webview-ui/src/utils/validate.ts             |   9 ++
 14 files changed, 383 insertions(+), 187 deletions(-)
 create mode 100644 src/api/openai.ts

diff --git a/src/api/index.ts b/src/api/index.ts
index bbafc1f..59b2bbc 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -1,9 +1,10 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiConfiguration, ApiModelId, ModelInfo } from "../shared/api"
+import { ApiConfiguration, ModelInfo } from "../shared/api"
 import { AnthropicHandler } from "./anthropic"
 import { AwsBedrockHandler } from "./bedrock"
 import { OpenRouterHandler } from "./openrouter"
 import { VertexHandler } from "./vertex"
+import { OpenAiHandler } from "./openai"
 
 export interface ApiHandlerMessageResponse {
 	message: Anthropic.Messages.Message
@@ -26,7 +27,7 @@ export interface ApiHandler {
 		>
 	): any
 
-	getModel(): { id: ApiModelId; info: ModelInfo }
+	getModel(): { id: string; info: ModelInfo }
 }
 
 export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
@@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new AwsBedrockHandler(options)
 		case "vertex":
 			return new VertexHandler(options)
+		case "openai":
+			return new OpenAiHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}
diff --git a/src/api/openai.ts b/src/api/openai.ts
new file mode 100644
index 0000000..afec4b4
--- /dev/null
+++ b/src/api/openai.ts
@@ -0,0 +1,74 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { ApiHandler, ApiHandlerMessageResponse, withoutImageData } from "."
+import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
+import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
+
+export class OpenAiHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: OpenAI
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		this.client = new OpenAI({
+			baseURL: this.options.openAiBaseUrl,
+			apiKey: this.options.openAiApiKey,
+		})
+	}
+
+	async createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		tools: Anthropic.Messages.Tool[]
+	): Promise<ApiHandlerMessageResponse> {
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
+		const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
+			type: "function",
+			function: {
+				name: tool.name,
+				description: tool.description,
+				parameters: tool.input_schema,
+			},
+		}))
+		const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: this.options.openAiModelId ?? "",
+			messages: openAiMessages,
+			tools: openAiTools,
+			tool_choice: "auto",
+		}
+		const completion = await this.client.chat.completions.create(createParams)
+		const errorMessage = (completion as any).error?.message
+		if (errorMessage) {
+			throw new Error(errorMessage)
+		}
+		const anthropicMessage = convertToAnthropicMessage(completion)
+		return { message: anthropicMessage }
+	}
+
+	createUserReadableRequest(
+		userContent: Array<
+			| Anthropic.TextBlockParam
+			| Anthropic.ImageBlockParam
+			| Anthropic.ToolUseBlockParam
+			| Anthropic.ToolResultBlockParam
+		>
+	): any {
+		return {
+			model: this.options.openAiModelId ?? "",
+			system: "(see SYSTEM_PROMPT in src/ClaudeDev.ts)",
+			messages: [{ conversation_history: "..." }, { role: "user", content: withoutImageData(userContent) }],
+			tools: "(see tools in src/ClaudeDev.ts)",
+			tool_choice: "auto",
+		}
+	}
+
+	getModel(): { id: string; info: ModelInfo } {
+		return {
+			id: this.options.openAiModelId ?? "",
+			info: openAiModelInfoSaneDefaults,
+		}
+	}
+}
diff --git a/src/api/openrouter.ts b/src/api/openrouter.ts
index 992181d..6970540 100644
--- a/src/api/openrouter.ts
+++ b/src/api/openrouter.ts
@@ -8,7 +8,7 @@ import {
 	OpenRouterModelId,
 	openRouterModels,
 } from "../shared/api"
-import { convertToOpenAiMessages } from "../utils/openai-format"
+import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
 
 export class OpenRouterHandler implements ApiHandler {
 	private options: ApiHandlerOptions
@@ -68,57 +68,7 @@ export class OpenRouterHandler implements ApiHandler {
 			throw new Error(errorMessage)
 		}
 
-		// Convert OpenAI response to Anthropic format
-		const openAiMessage = completion.choices[0].message
-		const anthropicMessage: Anthropic.Messages.Message = {
-			id: completion.id,
-			type: "message",
-			role: openAiMessage.role, // always "assistant"
-			content: [
-				{
-					type: "text",
-					text: openAiMessage.content || "",
-				},
-			],
-			model: completion.model,
-			stop_reason: (() => {
-				switch (completion.choices[0].finish_reason) {
-					case "stop":
-						return "end_turn"
-					case "length":
-						return "max_tokens"
-					case "tool_calls":
-						return "tool_use"
-					case "content_filter": // Anthropic doesn't have an exact equivalent
-					default:
-						return null
-				}
-			})(),
-			stop_sequence: null, // which custom stop_sequence was generated, if any (not applicable if you don't use stop_sequence)
-			usage: {
-				input_tokens: completion.usage?.prompt_tokens || 0,
-				output_tokens: completion.usage?.completion_tokens || 0,
-			},
-		}
-
-		if (openAiMessage.tool_calls && openAiMessage.tool_calls.length > 0) {
-			anthropicMessage.content.push(
-				...openAiMessage.tool_calls.map((toolCall): Anthropic.ToolUseBlock => {
-					let parsedInput = {}
-					try {
-						parsedInput = JSON.parse(toolCall.function.arguments || "{}")
-					} catch (error) {
-						console.error("Failed to parse tool arguments:", error)
-					}
-					return {
-						type: "tool_use",
-						id: toolCall.id,
-						name: toolCall.function.name,
-						input: parsedInput,
-					}
-				})
-			)
-		}
+		const anthropicMessage = convertToAnthropicMessage(completion)
 
 		return { message: anthropicMessage }
 	}
diff --git a/src/providers/ClaudeDevProvider.ts b/src/providers/ClaudeDevProvider.ts
index 9c02256..4ff8349 100644
--- a/src/providers/ClaudeDevProvider.ts
+++ b/src/providers/ClaudeDevProvider.ts
@@ -1,7 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import * as vscode from "vscode"
 import { ClaudeDev } from "../ClaudeDev"
-import { ApiModelId, ApiProvider } from "../shared/api"
+import { ApiProvider } from "../shared/api"
 import { ExtensionMessage } from "../shared/ExtensionMessage"
 import { WebviewMessage } from "../shared/WebviewMessage"
 import { downloadTask, findLast, getNonce, getUri, selectImages } from "../utils"
@@ -16,7 +16,7 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
 https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
 */
 
-type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken"
+type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" | "openAiApiKey"
 type GlobalStateKey =
 	| "apiProvider"
 	| "apiModelId"
@@ -27,6 +27,8 @@ type GlobalStateKey =
 	| "customInstructions"
 	| "alwaysAllowReadOnly"
 	| "taskHistory"
+	| "openAiBaseUrl"
+	| "openAiModelId"
 
 export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 	public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
@@ -314,6 +316,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 							awsRegion,
 							vertexProjectId,
 							vertexRegion,
+							openAiBaseUrl,
+							openAiApiKey,
+							openAiModelId,
 						} = message.apiConfiguration
 						await this.updateGlobalState("apiProvider", apiProvider)
 						await this.updateGlobalState("apiModelId", apiModelId)
@@ -325,6 +330,9 @@
 						await this.updateGlobalState("awsRegion", awsRegion)
 						await this.updateGlobalState("vertexProjectId", vertexProjectId)
 						await this.updateGlobalState("vertexRegion", vertexRegion)
+						await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
+						await this.storeSecret("openAiApiKey", openAiApiKey)
+						await this.updateGlobalState("openAiModelId", openAiModelId)
 						this.claudeDev?.updateApi(message.apiConfiguration)
 					}
 					await this.postStateToWebview()
@@ -615,13 +623,16 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			awsRegion,
 			vertexProjectId,
 			vertexRegion,
+			openAiBaseUrl,
+			openAiApiKey,
+			openAiModelId,
 			lastShownAnnouncementId,
 			customInstructions,
 			alwaysAllowReadOnly,
 			taskHistory,
 		] = await Promise.all([
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
-			this.getGlobalState("apiModelId") as Promise<ApiModelId | undefined>,
+			this.getGlobalState("apiModelId") as Promise<string | undefined>,
 			this.getSecret("apiKey") as Promise<string | undefined>,
 			this.getSecret("openRouterApiKey") as Promise<string | undefined>,
 			this.getSecret("awsAccessKey") as Promise<string | undefined>,
@@ -630,6 +641,9 @@
 			this.getGlobalState("awsRegion") as Promise<string | undefined>,
 			this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
 			this.getGlobalState("vertexRegion") as Promise<string | undefined>,
+			this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
+			this.getSecret("openAiApiKey") as Promise<string | undefined>,
+			this.getGlobalState("openAiModelId") as Promise<string | undefined>,
 			this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
 			this.getGlobalState("customInstructions") as Promise<string | undefined>,
 			this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -662,6 +676,9 @@
 				awsRegion,
 				vertexProjectId,
 				vertexRegion,
+				openAiBaseUrl,
+				openAiApiKey,
+				openAiModelId,
 			},
 			lastShownAnnouncementId,
 			customInstructions,
@@ -739,6 +756,7 @@
 			"awsAccessKey",
 			"awsSecretKey",
 			"awsSessionToken",
+			"openAiApiKey",
 		]
 		for (const key of secretKeys) {
 			await this.storeSecret(key, undefined)
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 38e926f..6c2a3de 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -1,7 +1,7 @@
-export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex"
+export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai"
 
 export interface ApiHandlerOptions {
-	apiModelId?: ApiModelId
+	apiModelId?: string
 	apiKey?: string // anthropic
 	openRouterApiKey?: string
 	awsAccessKey?: string
@@ -10,6 +10,9 @@ export interface ApiHandlerOptions {
 	awsRegion?: string
 	vertexProjectId?: string
 	vertexRegion?: string
+	openAiBaseUrl?: string
+	openAiApiKey?: string
+	openAiModelId?: string
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -29,8 +32,6 @@ export interface ModelInfo {
 	cacheReadsPrice?: number
 }
 
-export type ApiModelId = AnthropicModelId | OpenRouterModelId | BedrockModelId | VertexModelId
-
 // Anthropic
 // https://docs.anthropic.com/en/docs/about-claude/models
 export type AnthropicModelId = keyof typeof anthropicModels
@@ -292,3 +293,12 @@ export const vertexModels = {
 		outputPrice: 1.25,
 	},
 } as const satisfies Record<string, ModelInfo>
+
+export const openAiModelInfoSaneDefaults: ModelInfo = {
+	maxTokens: -1,
+	contextWindow: 128_000,
+	supportsImages: true,
+	supportsPromptCache: false,
+	inputPrice: 0,
+	outputPrice: 0,
+}
diff --git a/src/utils/openai-format.ts b/src/utils/openai-format.ts
index 91cd837..ead1bc0 100644
--- a/src/utils/openai-format.ts
+++ b/src/utils/openai-format.ts
@@ -142,3 +142,60 @@
 
 	return openAiMessages
 }
+
+// Convert OpenAI response to Anthropic format
+export function convertToAnthropicMessage(
+	completion: OpenAI.Chat.Completions.ChatCompletion
+): Anthropic.Messages.Message {
+	const openAiMessage = completion.choices[0].message
+	const anthropicMessage: Anthropic.Messages.Message = {
+		id: completion.id,
+		type: "message",
+		role: openAiMessage.role, // always "assistant"
+		content: [
+			{
+				type: "text",
+				text: openAiMessage.content || "",
+			},
+		],
+		model: completion.model,
+		stop_reason: (() => {
+			switch (completion.choices[0].finish_reason) {
+				case "stop":
+					return "end_turn"
+				case "length":
+					return "max_tokens"
+				case "tool_calls":
+					return "tool_use"
+				case "content_filter": // Anthropic doesn't have an exact equivalent
+				default:
+					return null
+			}
+		})(),
+		stop_sequence: null, // which custom stop_sequence was generated, if any (not applicable if you don't use stop_sequence)
+		usage: {
+			input_tokens: completion.usage?.prompt_tokens || 0,
+			output_tokens: completion.usage?.completion_tokens || 0,
+		},
+	}
+
+	if (openAiMessage.tool_calls && openAiMessage.tool_calls.length > 0) {
+		anthropicMessage.content.push(
+			...openAiMessage.tool_calls.map((toolCall): Anthropic.ToolUseBlock => {
+				let parsedInput = {}
+				try {
+					parsedInput = JSON.parse(toolCall.function.arguments || "{}")
+				} catch (error) {
+					console.error("Failed to parse tool arguments:", error)
+				}
+				return {
+					type: "tool_use",
+					id: toolCall.id,
+					name: toolCall.function.name,
+					input: parsedInput,
+				}
+			})
+		)
+	}
+	return anthropicMessage
+}
diff --git a/webview-ui/src/components/ApiOptions.tsx b/webview-ui/src/components/ApiOptions.tsx
index e0d657d..327ac41 100644
--- a/webview-ui/src/components/ApiOptions.tsx
+++ b/webview-ui/src/components/ApiOptions.tsx
@@ -2,12 +2,12 @@ import { VSCodeDropdown, VSCodeLink, VSCodeOption, VSCodeTextField } from "@vsco
 import React, { useMemo } from "react"
 import {
 	ApiConfiguration,
-	ApiModelId,
 	ModelInfo,
 	anthropicDefaultModelId,
 	anthropicModels,
 	bedrockDefaultModelId,
 	bedrockModels,
+	openAiModelInfoSaneDefaults,
 	openRouterDefaultModelId,
 	openRouterModels,
 	vertexDefaultModelId,
@@ -69,11 +69,16 @@ const ApiOptions: React.FC = ({ showModelOptions, apiErrorMessa
-
+
 				Anthropic
 				OpenRouter
 				AWS Bedrock
 				GCP Vertex AI
+				OpenAI Compatible
@@ -256,6 +261,47 @@ const ApiOptions: React.FC = ({ showModelOptions, apiErrorMessa
 			)}
+			{selectedProvider === "openai" && (
+
+
+				Base URL
+
+
+				API Key
+
+
+				Model ID
+
+
+				You can use any OpenAI compatible API with models that support tool use.{" "}
+
+				(Note: Claude Dev uses complex prompts, so results
+				may vary depending on the quality of the model you choose. Less capable models may not work
+				as expected.)
+
+
+
+			)}
+			{apiErrorMessage && (

= ({ showModelOptions, apiErrorMessa

 			)}
-			{showModelOptions && (
+			{selectedProvider !== "openai" && showModelOptions && (
 				<>
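A minimal usage sketch (an editorial addition, not part of the patch) of how the settings collected above reach the new provider. The import paths, base URL, key, and model ID are placeholders; any OpenAI-compatible endpoint and model name can be substituted.

import { buildApiHandler } from "./src/api"
import { ApiConfiguration } from "./src/shared/api"

// Placeholder configuration for an arbitrary OpenAI-compatible endpoint.
const configuration: ApiConfiguration = {
	apiProvider: "openai",
	openAiBaseUrl: "http://localhost:11434/v1", // assumed example endpoint
	openAiApiKey: "sk-placeholder", // persisted by the provider via storeSecret("openAiApiKey", ...)
	openAiModelId: "my-model", // free-form string; this provider has no fixed model list
}

// buildApiHandler routes apiProvider "openai" to OpenAiHandler; getModel() returns the
// free-form model id together with openAiModelInfoSaneDefaults (128k context window).
const handler = buildApiHandler(configuration)
console.log(handler.getModel())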
diff --git a/webview-ui/src/components/HistoryView.tsx b/webview-ui/src/components/HistoryView.tsx
index 14ab42e..5f1f89d 100644
--- a/webview-ui/src/components/HistoryView.tsx
+++ b/webview-ui/src/components/HistoryView.tsx
@@ -63,6 +63,17 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
 		)
 	}
 
+	const ExportButton = ({ itemId }: { itemId: string }) => (
+
+			{
+				e.stopPropagation()
+				handleExportMd(itemId)
+			}}>
+
+			EXPORT .MD
+
+	)
+
 	return (
 		<>
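As a companion sketch (also an editorial addition, not part of the patch), this is roughly how the convertToAnthropicMessage helper added in src/utils/openai-format.ts behaves for a tool-calling completion; the ids, tool name, and token counts are invented for illustration, and the import path is assumed.

import { convertToAnthropicMessage } from "./src/utils/openai-format"

// Abbreviated OpenAI-style chat completion (illustrative values only; the cast below is
// needed because this literal omits fields such as `created` and `object`).
const completion = {
	id: "chatcmpl-abc123",
	model: "my-model",
	choices: [
		{
			message: {
				role: "assistant" as const,
				content: "I'll read that file first.",
				tool_calls: [
					{
						id: "call_1",
						type: "function" as const,
						function: { name: "read_file", arguments: `{"path":"src/index.ts"}` },
					},
				],
			},
			finish_reason: "tool_calls" as const,
		},
	],
	usage: { prompt_tokens: 1200, completion_tokens: 48 },
}

// Per the helper: the text content becomes a text block, finish_reason "tool_calls" maps to
// stop_reason "tool_use", and each tool call's arguments string is JSON.parsed into the
// corresponding tool_use block's input.
const anthropicMessage = convertToAnthropicMessage(completion as any)
console.log(anthropicMessage.stop_reason) // "tool_use"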