From 4b44e8f9217e91e8a27c747c6808204f8e549642 Mon Sep 17 00:00:00 2001 From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:01:28 -0400 Subject: [PATCH] Add OpenAI provider --- src/api/index.ts | 3 + src/api/openai-native.ts | 65 +++++++++++++++++++ src/providers/ClaudeDevProvider.ts | 7 ++ src/shared/api.ts | 50 +++++++++++++- webview-ui/src/components/ApiOptions.tsx | 56 ++++++++++++++-- .../src/context/ExtensionStateContext.tsx | 1 + webview-ui/src/utils/validate.ts | 5 ++ 7 files changed, 182 insertions(+), 5 deletions(-) create mode 100644 src/api/openai-native.ts diff --git a/src/api/index.ts b/src/api/index.ts index e741bc0..e0e4108 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -7,6 +7,7 @@ import { VertexHandler } from "./vertex" import { OpenAiHandler } from "./openai" import { OllamaHandler } from "./ollama" import { GeminiHandler } from "./gemini" +import { OpenAiNativeHandler } from "./openai-native" export interface ApiHandlerMessageResponse { message: Anthropic.Messages.Message @@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler { return new OllamaHandler(options) case "gemini": return new GeminiHandler(options) + case "openai-native": + return new OpenAiNativeHandler(options) default: return new AnthropicHandler(options) } diff --git a/src/api/openai-native.ts b/src/api/openai-native.ts new file mode 100644 index 0000000..421d537 --- /dev/null +++ b/src/api/openai-native.ts @@ -0,0 +1,65 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" +import { ApiHandler, ApiHandlerMessageResponse } from "." +import { + ApiHandlerOptions, + ModelInfo, + openAiNativeDefaultModelId, + OpenAiNativeModelId, + openAiNativeModels, +} from "../shared/api" +import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format" + +export class OpenAiNativeHandler implements ApiHandler { + private options: ApiHandlerOptions + private client: OpenAI + + constructor(options: ApiHandlerOptions) { + this.options = options + this.client = new OpenAI({ + apiKey: this.options.openAiNativeApiKey, + }) + } + + async createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + tools: Anthropic.Messages.Tool[] + ): Promise { + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ + { role: "system", content: systemPrompt }, + ...convertToOpenAiMessages(messages), + ] + const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({ + type: "function", + function: { + name: tool.name, + description: tool.description, + parameters: tool.input_schema, + }, + })) + const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: this.getModel().id, + max_tokens: this.getModel().info.maxTokens, + messages: openAiMessages, + tools: openAiTools, + tool_choice: "auto", + } + const completion = await this.client.chat.completions.create(createParams) + const errorMessage = (completion as any).error?.message + if (errorMessage) { + throw new Error(errorMessage) + } + const anthropicMessage = convertToAnthropicMessage(completion) + return { message: anthropicMessage } + } + + getModel(): { id: OpenAiNativeModelId; info: ModelInfo } { + const modelId = this.options.apiModelId + if (modelId && modelId in openAiNativeModels) { + const id = modelId as OpenAiNativeModelId + return { id, info: openAiNativeModels[id] } + } + return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] } + } +} diff --git a/src/providers/ClaudeDevProvider.ts b/src/providers/ClaudeDevProvider.ts index ede6a00..5cdc075 100644 --- a/src/providers/ClaudeDevProvider.ts +++ b/src/providers/ClaudeDevProvider.ts @@ -26,6 +26,7 @@ type SecretKey = | "awsSessionToken" | "openAiApiKey" | "geminiApiKey" + | "openAiNativeApiKey" type GlobalStateKey = | "apiProvider" | "apiModelId" @@ -337,6 +338,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { ollamaBaseUrl, anthropicBaseUrl, geminiApiKey, + openAiNativeApiKey, } = message.apiConfiguration await this.updateGlobalState("apiProvider", apiProvider) await this.updateGlobalState("apiModelId", apiModelId) @@ -355,6 +357,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl) await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl) await this.storeSecret("geminiApiKey", geminiApiKey) + await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey) this.claudeDev?.updateApi(message.apiConfiguration) } await this.postStateToWebview() @@ -677,6 +680,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { ollamaBaseUrl, anthropicBaseUrl, geminiApiKey, + openAiNativeApiKey, lastShownAnnouncementId, customInstructions, alwaysAllowReadOnly, @@ -699,6 +703,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { this.getGlobalState("ollamaBaseUrl") as Promise, this.getGlobalState("anthropicBaseUrl") as Promise, this.getSecret("geminiApiKey") as Promise, + this.getSecret("openAiNativeApiKey") as Promise, this.getGlobalState("lastShownAnnouncementId") as Promise, this.getGlobalState("customInstructions") as Promise, this.getGlobalState("alwaysAllowReadOnly") as Promise, @@ -738,6 +743,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { ollamaBaseUrl, anthropicBaseUrl, geminiApiKey, + openAiNativeApiKey, }, lastShownAnnouncementId, customInstructions, @@ -817,6 +823,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { "awsSessionToken", "openAiApiKey", "geminiApiKey", + "openAiNativeApiKey", ] for (const key of secretKeys) { await this.storeSecret(key, undefined) diff --git a/src/shared/api.ts b/src/shared/api.ts index a803a11..d4dfc5c 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1,4 +1,12 @@ -export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama" | "gemini" +export type ApiProvider = + | "anthropic" + | "openrouter" + | "bedrock" + | "vertex" + | "openai" + | "ollama" + | "gemini" + | "openai-native" export interface ApiHandlerOptions { apiModelId?: string @@ -17,6 +25,7 @@ export interface ApiHandlerOptions { ollamaModelId?: string ollamaBaseUrl?: string geminiApiKey?: string + openAiNativeApiKey?: string } export type ApiConfiguration = ApiHandlerOptions & { @@ -334,3 +343,42 @@ export const geminiModels = { outputPrice: 0, }, } as const satisfies Record + +// OpenAI Native +// https://openai.com/api/pricing/ +export type OpenAiNativeModelId = keyof typeof openAiNativeModels +export const openAiNativeDefaultModelId: OpenAiNativeModelId = "o1-preview" +export const openAiNativeModels = { + "o1-preview": { + maxTokens: 32_768, + contextWindow: 128_000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 15, + outputPrice: 60, + }, + "o1-mini": { + maxTokens: 65_536, + contextWindow: 128_000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 3, + outputPrice: 12, + }, + "gpt-4o": { + maxTokens: 4_096, + contextWindow: 128_000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 5, + outputPrice: 15, + }, + "gpt-4o-mini": { + maxTokens: 16_384, + contextWindow: 128_000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + }, +} as const satisfies Record diff --git a/webview-ui/src/components/ApiOptions.tsx b/webview-ui/src/components/ApiOptions.tsx index c1f0c8e..d2c809b 100644 --- a/webview-ui/src/components/ApiOptions.tsx +++ b/webview-ui/src/components/ApiOptions.tsx @@ -12,6 +12,7 @@ import { useEvent, useInterval } from "react-use" import { ApiConfiguration, ModelInfo, + OpenAiNativeModelId, anthropicDefaultModelId, anthropicModels, bedrockDefaultModelId, @@ -19,6 +20,8 @@ import { geminiDefaultModelId, geminiModels, openAiModelInfoSaneDefaults, + openAiNativeDefaultModelId, + openAiNativeModels, openRouterDefaultModelId, openRouterModels, vertexDefaultModelId, @@ -112,10 +115,11 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => { onChange={handleInputChange("apiProvider")} style={{ minWidth: 130 }}> Anthropic + OpenAI OpenRouter + Google Gemini AWS Bedrock GCP Vertex AI - Google Gemini OpenAI Compatible Ollama @@ -174,6 +178,34 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => { )} + {selectedProvider === "openai-native" && ( +
+ + OpenAI API Key + +

+ This key is stored locally and only used to make API requests from this extension. + {!apiConfiguration?.openAiNativeApiKey && ( + + You can get an OpenAI API key by signing up here. + + )} +

+
+ )} + {selectedProvider === "openrouter" && (
{ {selectedProvider === "bedrock" && createDropdown(bedrockModels)} {selectedProvider === "vertex" && createDropdown(vertexModels)} {selectedProvider === "gemini" && createDropdown(geminiModels)} + {selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
@@ -514,6 +547,7 @@ export const formatPrice = (price: number) => { const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => { const isGemini = Object.keys(geminiModels).includes(selectedModelId) + const isO1 = (["o1-preview", "o1-mini"] as OpenAiNativeModelId[]).includes(selectedModelId as OpenAiNativeModelId) return (

)} Max output: {modelInfo?.maxTokens?.toLocaleString()} tokens -
{modelInfo.inputPrice > 0 && ( <> +
Input price: {formatPrice(modelInfo.inputPrice)}/million tokens -
)} {modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && ( <> +
Cache writes price:{" "} {formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens
Cache reads price:{" "} {formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens -
)} {modelInfo.outputPrice > 0 && ( <> +
Output price: {formatPrice(modelInfo.outputPrice)}/million tokens )} {isGemini && ( <> +
)} + {isO1 && ( + <> +
+ + * This model is newly released and may not be accessible to all users yet. + + + )}

) } @@ -632,6 +678,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) { return getProviderData(vertexModels, vertexDefaultModelId) case "gemini": return getProviderData(geminiModels, geminiDefaultModelId) + case "openai-native": + return getProviderData(openAiNativeModels, openAiNativeDefaultModelId) case "openai": return { selectedProvider: provider, diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index cb9c970..791307c 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -41,6 +41,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode config.openAiApiKey, config.ollamaModelId, config.geminiApiKey, + config.openAiNativeApiKey, ].some((key) => key !== undefined) : false setShowWelcome(!hasKey) diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 5e3dd1b..06ed416 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -28,6 +28,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s return "You must provide a valid API key or choose a different provider." } break + case "openai-native": + if (!apiConfiguration.openAiNativeApiKey) { + return "You must provide a valid API key or choose a different provider." + } + break case "openai": if ( !apiConfiguration.openAiBaseUrl ||