Add OpenAI provider

This commit is contained in:
Saoud Rizwan
2024-09-12 15:01:28 -04:00
parent cb8ce1685f
commit 4b44e8f921
7 changed files with 182 additions and 5 deletions

View File

@@ -7,6 +7,7 @@ import { VertexHandler } from "./vertex"
import { OpenAiHandler } from "./openai"
import { OllamaHandler } from "./ollama"
import { GeminiHandler } from "./gemini"
import { OpenAiNativeHandler } from "./openai-native"
export interface ApiHandlerMessageResponse {
message: Anthropic.Messages.Message
@@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new OllamaHandler(options)
case "gemini":
return new GeminiHandler(options)
case "openai-native":
return new OpenAiNativeHandler(options)
default:
return new AnthropicHandler(options)
}

65
src/api/openai-native.ts Normal file
View File

@@ -0,0 +1,65 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import {
ApiHandlerOptions,
ModelInfo,
openAiNativeDefaultModelId,
OpenAiNativeModelId,
openAiNativeModels,
} from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
export class OpenAiNativeHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
apiKey: this.options.openAiNativeApiKey,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
const anthropicMessage = convertToAnthropicMessage(completion)
return { message: anthropicMessage }
}
getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in openAiNativeModels) {
const id = modelId as OpenAiNativeModelId
return { id, info: openAiNativeModels[id] }
}
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
}
}

View File

@@ -26,6 +26,7 @@ type SecretKey =
| "awsSessionToken"
| "openAiApiKey"
| "geminiApiKey"
| "openAiNativeApiKey"
type GlobalStateKey =
| "apiProvider"
| "apiModelId"
@@ -337,6 +338,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
} = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId)
@@ -355,6 +357,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
await this.storeSecret("geminiApiKey", geminiApiKey)
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
this.claudeDev?.updateApi(message.apiConfiguration)
}
await this.postStateToWebview()
@@ -677,6 +680,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
lastShownAnnouncementId,
customInstructions,
alwaysAllowReadOnly,
@@ -699,6 +703,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getGlobalState("ollamaBaseUrl") as Promise<string | undefined>,
this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
this.getGlobalState("customInstructions") as Promise<string | undefined>,
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -738,6 +743,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
},
lastShownAnnouncementId,
customInstructions,
@@ -817,6 +823,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
"awsSessionToken",
"openAiApiKey",
"geminiApiKey",
"openAiNativeApiKey",
]
for (const key of secretKeys) {
await this.storeSecret(key, undefined)

View File

@@ -1,4 +1,12 @@
export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama" | "gemini"
export type ApiProvider =
| "anthropic"
| "openrouter"
| "bedrock"
| "vertex"
| "openai"
| "ollama"
| "gemini"
| "openai-native"
export interface ApiHandlerOptions {
apiModelId?: string
@@ -17,6 +25,7 @@ export interface ApiHandlerOptions {
ollamaModelId?: string
ollamaBaseUrl?: string
geminiApiKey?: string
openAiNativeApiKey?: string
}
export type ApiConfiguration = ApiHandlerOptions & {
@@ -334,3 +343,42 @@ export const geminiModels = {
outputPrice: 0,
},
} as const satisfies Record<string, ModelInfo>
// OpenAI Native
// https://openai.com/api/pricing/
export type OpenAiNativeModelId = keyof typeof openAiNativeModels
export const openAiNativeDefaultModelId: OpenAiNativeModelId = "o1-preview"
export const openAiNativeModels = {
"o1-preview": {
maxTokens: 32_768,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 15,
outputPrice: 60,
},
"o1-mini": {
maxTokens: 65_536,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 3,
outputPrice: 12,
},
"gpt-4o": {
maxTokens: 4_096,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 5,
outputPrice: 15,
},
"gpt-4o-mini": {
maxTokens: 16_384,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 0.15,
outputPrice: 0.6,
},
} as const satisfies Record<string, ModelInfo>