Add OpenAI provider

This commit is contained in:
Saoud Rizwan
2024-09-12 15:01:28 -04:00
parent cb8ce1685f
commit 4b44e8f921
7 changed files with 182 additions and 5 deletions

View File

@@ -12,6 +12,7 @@ import { useEvent, useInterval } from "react-use"
import {
ApiConfiguration,
ModelInfo,
OpenAiNativeModelId,
anthropicDefaultModelId,
anthropicModels,
bedrockDefaultModelId,
@@ -19,6 +20,8 @@ import {
geminiDefaultModelId,
geminiModels,
openAiModelInfoSaneDefaults,
openAiNativeDefaultModelId,
openAiNativeModels,
openRouterDefaultModelId,
openRouterModels,
vertexDefaultModelId,
@@ -112,10 +115,11 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
onChange={handleInputChange("apiProvider")}
style={{ minWidth: 130 }}>
<VSCodeOption value="anthropic">Anthropic</VSCodeOption>
<VSCodeOption value="openai-native">OpenAI</VSCodeOption>
<VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
<VSCodeOption value="ollama">Ollama</VSCodeOption>
</VSCodeDropdown>
@@ -174,6 +178,34 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
</div>
)}
{selectedProvider === "openai-native" && (
<div>
<VSCodeTextField
value={apiConfiguration?.openAiNativeApiKey || ""}
style={{ width: "100%" }}
type="password"
onInput={handleInputChange("openAiNativeApiKey")}
placeholder="Enter API Key...">
<span style={{ fontWeight: 500 }}>OpenAI API Key</span>
</VSCodeTextField>
<p
style={{
fontSize: "12px",
marginTop: 3,
color: "var(--vscode-descriptionForeground)",
}}>
This key is stored locally and only used to make API requests from this extension.
{!apiConfiguration?.openAiNativeApiKey && (
<VSCodeLink
href="https://platform.openai.com/api-keys"
style={{ display: "inline", fontSize: "inherit" }}>
You can get an OpenAI API key by signing up here.
</VSCodeLink>
)}
</p>
</div>
)}
{selectedProvider === "openrouter" && (
<div>
<VSCodeTextField
@@ -490,6 +522,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
{selectedProvider === "vertex" && createDropdown(vertexModels)}
{selectedProvider === "gemini" && createDropdown(geminiModels)}
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
</div>
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
@@ -514,6 +547,7 @@ export const formatPrice = (price: number) => {
const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
const isGemini = Object.keys(geminiModels).includes(selectedModelId)
const isO1 = (["o1-preview", "o1-mini"] as OpenAiNativeModelId[]).includes(selectedModelId as OpenAiNativeModelId)
return (
<p style={{ fontSize: "12px", marginTop: "2px", color: "var(--vscode-descriptionForeground)" }}>
<ModelInfoSupportsItem
@@ -533,32 +567,33 @@ const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string
</>
)}
<span style={{ fontWeight: 500 }}>Max output:</span> {modelInfo?.maxTokens?.toLocaleString()} tokens
<br />
{modelInfo.inputPrice > 0 && (
<>
<br />
<span style={{ fontWeight: 500 }}>Input price:</span> {formatPrice(modelInfo.inputPrice)}/million
tokens
<br />
</>
)}
{modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && (
<>
<br />
<span style={{ fontWeight: 500 }}>Cache writes price:</span>{" "}
{formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens
<br />
<span style={{ fontWeight: 500 }}>Cache reads price:</span>{" "}
{formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens
<br />
</>
)}
{modelInfo.outputPrice > 0 && (
<>
<br />
<span style={{ fontWeight: 500 }}>Output price:</span> {formatPrice(modelInfo.outputPrice)}/million
tokens
</>
)}
{isGemini && (
<>
<br />
<span
style={{
fontStyle: "italic",
@@ -573,6 +608,17 @@ const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string
</span>
</>
)}
{isO1 && (
<>
<br />
<span
style={{
fontStyle: "italic",
}}>
* This model is newly released and may not be accessible to all users yet.
</span>
</>
)}
</p>
)
}
@@ -632,6 +678,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
return getProviderData(vertexModels, vertexDefaultModelId)
case "gemini":
return getProviderData(geminiModels, geminiDefaultModelId)
case "openai-native":
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
case "openai":
return {
selectedProvider: provider,

View File

@@ -41,6 +41,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
config.openAiApiKey,
config.ollamaModelId,
config.geminiApiKey,
config.openAiNativeApiKey,
].some((key) => key !== undefined)
: false
setShowWelcome(!hasKey)

View File

@@ -28,6 +28,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
return "You must provide a valid API key or choose a different provider."
}
break
case "openai-native":
if (!apiConfiguration.openAiNativeApiKey) {
return "You must provide a valid API key or choose a different provider."
}
break
case "openai":
if (
!apiConfiguration.openAiBaseUrl ||