diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 13922a4..c2c3a89 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -108,7 +108,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler { getModel(): { id: string; info: ModelInfo } { return { id: this.options.openAiModelId ?? "", - info: openAiModelInfoSaneDefaults, + info: this.options.openAiCusModelInfo ?? openAiModelInfoSaneDefaults, } } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 72e7e27..771aae3 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -68,6 +68,7 @@ type GlobalStateKey = | "taskHistory" | "openAiBaseUrl" | "openAiModelId" + | "openAiCusModelInfo" | "ollamaModelId" | "ollamaBaseUrl" | "lmStudioModelId" @@ -1198,6 +1199,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openAiBaseUrl, openAiApiKey, openAiModelId, + openAiCusModelInfo, ollamaModelId, ollamaBaseUrl, lmStudioModelId, @@ -1231,6 +1233,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl) await this.storeSecret("openAiApiKey", openAiApiKey) await this.updateGlobalState("openAiModelId", openAiModelId) + await this.updateGlobalState("openAiCusModelInfo", openAiCusModelInfo) await this.updateGlobalState("ollamaModelId", ollamaModelId) await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl) await this.updateGlobalState("lmStudioModelId", lmStudioModelId) @@ -1847,6 +1850,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openAiBaseUrl, openAiApiKey, openAiModelId, + openAiCusModelInfo, ollamaModelId, ollamaBaseUrl, lmStudioModelId, @@ -1910,6 +1914,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getGlobalState("openAiBaseUrl") as Promise, this.getSecret("openAiApiKey") as Promise, this.getGlobalState("openAiModelId") as Promise, + this.getGlobalState("openAiCusModelInfo") as Promise, this.getGlobalState("ollamaModelId") as Promise, this.getGlobalState("ollamaBaseUrl") as Promise, this.getGlobalState("lmStudioModelId") as Promise, @@ -1990,6 +1995,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openAiBaseUrl, openAiApiKey, openAiModelId, + openAiCusModelInfo, ollamaModelId, ollamaBaseUrl, lmStudioModelId, diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index ce05976..5aaaa82 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -76,6 +76,7 @@ export interface WebviewMessage { | "autoApprovalEnabled" | "updateCustomMode" | "deleteCustomMode" + | "setOpenAiCusModelInfo" text?: string disabled?: boolean askResponse?: ClineAskResponse diff --git a/src/shared/api.ts b/src/shared/api.ts index 8f65c67..5524a1e 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -38,6 +38,7 @@ export interface ApiHandlerOptions { openAiBaseUrl?: string openAiApiKey?: string openAiModelId?: string + openAiCusModelInfo?: ModelInfo ollamaModelId?: string ollamaBaseUrl?: string lmStudioModelId?: string diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 561ae5b..9fbb6d8 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -550,6 +550,184 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = placeholder={`Default: ${azureOpenAiDefaultApiVersion}`} /> )} + + {/* Model Info Configuration */} +
+
+ Model Configuration +

+ Configure the capabilities and pricing for your custom OpenAI-compatible model +

+
+ + {/* Capabilities Section */} +
+ Capabilities +
+ { + const value = parseInt(e.target.value) + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + maxTokens: isNaN(value) ? undefined : value + } + }) + }} + placeholder="e.g. 4096"> + Max Output Tokens + + + { + const parsed = parseInt(e.target.value) + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + contextWindow: e.target.value === "" ? undefined : (isNaN(parsed) ? openAiModelInfoSaneDefaults.contextWindow : parsed) + } + }) + }} + placeholder="e.g. 128000"> + Context Window Size + + +
+ { + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + supportsImages: e.target.checked + } + }) + }}> + Supports Images + + + { + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + supportsComputerUse: e.target.checked + } + }) + }}> + Supports Computer Use + +
+
+
+ + {/* Pricing Section */} +
+ Pricing (USD per million tokens) +
+ {/* Input/Output Prices */} +
+ { + const parsed = parseFloat(e.target.value) + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + inputPrice: e.target.value === "" ? undefined : (isNaN(parsed) ? openAiModelInfoSaneDefaults.inputPrice : parsed) + } + }) + }} + placeholder="e.g. 0.0001"> + Input Price + + + { + const parsed = parseFloat(e.target.value) + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + outputPrice: e.target.value === "" ? undefined : (isNaN(parsed) ? openAiModelInfoSaneDefaults.outputPrice : parsed) + } + }) + }} + placeholder="e.g. 0.0002"> + Output Price + +
+ + {/* Cache Prices */} +
+ { + const parsed = parseFloat(e.target.value) + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + cacheWritesPrice: e.target.value === "" ? undefined : (isNaN(parsed) ? openAiModelInfoSaneDefaults.cacheWritesPrice : parsed) + } + }) + }} + placeholder="e.g. 0.0001"> + Cache Write Price + + + { + const parsed = parseFloat(e.target.value) + setApiConfiguration({ + ...apiConfiguration, + openAiCusModelInfo: { + ...(apiConfiguration?.openAiCusModelInfo || openAiModelInfoSaneDefaults), + cacheReadsPrice: e.target.value === "" ? undefined : (isNaN(parsed) ? openAiModelInfoSaneDefaults.cacheReadsPrice : parsed) + } + }) + }} + placeholder="e.g. 0.00001"> + Cache Read Price + +
+
+
+
+ + { /* TODO: model info here */} + +