Add OpenAI provider

This commit is contained in:
Saoud Rizwan
2024-09-12 15:01:28 -04:00
parent cb8ce1685f
commit 4b44e8f921
7 changed files with 182 additions and 5 deletions

View File

@@ -7,6 +7,7 @@ import { VertexHandler } from "./vertex"
import { OpenAiHandler } from "./openai"
import { OllamaHandler } from "./ollama"
import { GeminiHandler } from "./gemini"
import { OpenAiNativeHandler } from "./openai-native"
export interface ApiHandlerMessageResponse {
message: Anthropic.Messages.Message
@@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new OllamaHandler(options)
case "gemini":
return new GeminiHandler(options)
case "openai-native":
return new OpenAiNativeHandler(options)
default:
return new AnthropicHandler(options)
}

65
src/api/openai-native.ts Normal file
View File

@@ -0,0 +1,65 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import {
ApiHandlerOptions,
ModelInfo,
openAiNativeDefaultModelId,
OpenAiNativeModelId,
openAiNativeModels,
} from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
export class OpenAiNativeHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
apiKey: this.options.openAiNativeApiKey,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
const anthropicMessage = convertToAnthropicMessage(completion)
return { message: anthropicMessage }
}
getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in openAiNativeModels) {
const id = modelId as OpenAiNativeModelId
return { id, info: openAiNativeModels[id] }
}
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
}
}

View File

@@ -26,6 +26,7 @@ type SecretKey =
| "awsSessionToken"
| "openAiApiKey"
| "geminiApiKey"
| "openAiNativeApiKey"
type GlobalStateKey =
| "apiProvider"
| "apiModelId"
@@ -337,6 +338,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
} = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId)
@@ -355,6 +357,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
await this.storeSecret("geminiApiKey", geminiApiKey)
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
this.claudeDev?.updateApi(message.apiConfiguration)
}
await this.postStateToWebview()
@@ -677,6 +680,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
lastShownAnnouncementId,
customInstructions,
alwaysAllowReadOnly,
@@ -699,6 +703,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getGlobalState("ollamaBaseUrl") as Promise<string | undefined>,
this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
this.getGlobalState("customInstructions") as Promise<string | undefined>,
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -738,6 +743,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
},
lastShownAnnouncementId,
customInstructions,
@@ -817,6 +823,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
"awsSessionToken",
"openAiApiKey",
"geminiApiKey",
"openAiNativeApiKey",
]
for (const key of secretKeys) {
await this.storeSecret(key, undefined)

View File

@@ -1,4 +1,12 @@
export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama" | "gemini"
export type ApiProvider =
| "anthropic"
| "openrouter"
| "bedrock"
| "vertex"
| "openai"
| "ollama"
| "gemini"
| "openai-native"
export interface ApiHandlerOptions {
apiModelId?: string
@@ -17,6 +25,7 @@ export interface ApiHandlerOptions {
ollamaModelId?: string
ollamaBaseUrl?: string
geminiApiKey?: string
openAiNativeApiKey?: string
}
export type ApiConfiguration = ApiHandlerOptions & {
@@ -334,3 +343,42 @@ export const geminiModels = {
outputPrice: 0,
},
} as const satisfies Record<string, ModelInfo>
// OpenAI Native
// https://openai.com/api/pricing/
export type OpenAiNativeModelId = keyof typeof openAiNativeModels
export const openAiNativeDefaultModelId: OpenAiNativeModelId = "o1-preview"
export const openAiNativeModels = {
"o1-preview": {
maxTokens: 32_768,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 15,
outputPrice: 60,
},
"o1-mini": {
maxTokens: 65_536,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 3,
outputPrice: 12,
},
"gpt-4o": {
maxTokens: 4_096,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 5,
outputPrice: 15,
},
"gpt-4o-mini": {
maxTokens: 16_384,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 0.15,
outputPrice: 0.6,
},
} as const satisfies Record<string, ModelInfo>

View File

@@ -12,6 +12,7 @@ import { useEvent, useInterval } from "react-use"
import {
ApiConfiguration,
ModelInfo,
OpenAiNativeModelId,
anthropicDefaultModelId,
anthropicModels,
bedrockDefaultModelId,
@@ -19,6 +20,8 @@ import {
geminiDefaultModelId,
geminiModels,
openAiModelInfoSaneDefaults,
openAiNativeDefaultModelId,
openAiNativeModels,
openRouterDefaultModelId,
openRouterModels,
vertexDefaultModelId,
@@ -112,10 +115,11 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
onChange={handleInputChange("apiProvider")}
style={{ minWidth: 130 }}>
<VSCodeOption value="anthropic">Anthropic</VSCodeOption>
<VSCodeOption value="openai-native">OpenAI</VSCodeOption>
<VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
<VSCodeOption value="ollama">Ollama</VSCodeOption>
</VSCodeDropdown>
@@ -174,6 +178,34 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
</div>
)}
{selectedProvider === "openai-native" && (
<div>
<VSCodeTextField
value={apiConfiguration?.openAiNativeApiKey || ""}
style={{ width: "100%" }}
type="password"
onInput={handleInputChange("openAiNativeApiKey")}
placeholder="Enter API Key...">
<span style={{ fontWeight: 500 }}>OpenAI API Key</span>
</VSCodeTextField>
<p
style={{
fontSize: "12px",
marginTop: 3,
color: "var(--vscode-descriptionForeground)",
}}>
This key is stored locally and only used to make API requests from this extension.
{!apiConfiguration?.openAiNativeApiKey && (
<VSCodeLink
href="https://platform.openai.com/api-keys"
style={{ display: "inline", fontSize: "inherit" }}>
You can get an OpenAI API key by signing up here.
</VSCodeLink>
)}
</p>
</div>
)}
{selectedProvider === "openrouter" && (
<div>
<VSCodeTextField
@@ -490,6 +522,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
{selectedProvider === "vertex" && createDropdown(vertexModels)}
{selectedProvider === "gemini" && createDropdown(geminiModels)}
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
</div>
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
@@ -514,6 +547,7 @@ export const formatPrice = (price: number) => {
const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
const isGemini = Object.keys(geminiModels).includes(selectedModelId)
const isO1 = (["o1-preview", "o1-mini"] as OpenAiNativeModelId[]).includes(selectedModelId as OpenAiNativeModelId)
return (
<p style={{ fontSize: "12px", marginTop: "2px", color: "var(--vscode-descriptionForeground)" }}>
<ModelInfoSupportsItem
@@ -533,32 +567,33 @@ const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string
</>
)}
<span style={{ fontWeight: 500 }}>Max output:</span> {modelInfo?.maxTokens?.toLocaleString()} tokens
<br />
{modelInfo.inputPrice > 0 && (
<>
<br />
<span style={{ fontWeight: 500 }}>Input price:</span> {formatPrice(modelInfo.inputPrice)}/million
tokens
<br />
</>
)}
{modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && (
<>
<br />
<span style={{ fontWeight: 500 }}>Cache writes price:</span>{" "}
{formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens
<br />
<span style={{ fontWeight: 500 }}>Cache reads price:</span>{" "}
{formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens
<br />
</>
)}
{modelInfo.outputPrice > 0 && (
<>
<br />
<span style={{ fontWeight: 500 }}>Output price:</span> {formatPrice(modelInfo.outputPrice)}/million
tokens
</>
)}
{isGemini && (
<>
<br />
<span
style={{
fontStyle: "italic",
@@ -573,6 +608,17 @@ const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string
</span>
</>
)}
{isO1 && (
<>
<br />
<span
style={{
fontStyle: "italic",
}}>
* This model is newly released and may not be accessible to all users yet.
</span>
</>
)}
</p>
)
}
@@ -632,6 +678,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
return getProviderData(vertexModels, vertexDefaultModelId)
case "gemini":
return getProviderData(geminiModels, geminiDefaultModelId)
case "openai-native":
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
case "openai":
return {
selectedProvider: provider,

View File

@@ -41,6 +41,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
config.openAiApiKey,
config.ollamaModelId,
config.geminiApiKey,
config.openAiNativeApiKey,
].some((key) => key !== undefined)
: false
setShowWelcome(!hasKey)

View File

@@ -28,6 +28,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
return "You must provide a valid API key or choose a different provider."
}
break
case "openai-native":
if (!apiConfiguration.openAiNativeApiKey) {
return "You must provide a valid API key or choose a different provider."
}
break
case "openai":
if (
!apiConfiguration.openAiBaseUrl ||