Add OpenRouter custom model scheme

This commit is contained in:
Saoud Rizwan
2024-10-03 19:46:09 -04:00
parent d5b3bd7788
commit 7cb0c524e5
12 changed files with 262 additions and 54 deletions

View File

@@ -25,7 +25,7 @@ export class GeminiHandler implements ApiHandler {
const result = await model.generateContentStream({ const result = await model.generateContentStream({
contents: messages.map(convertAnthropicMessageToGemini), contents: messages.map(convertAnthropicMessageToGemini),
generationConfig: { generationConfig: {
maxOutputTokens: this.getModel().info.maxTokens, // maxOutputTokens: this.getModel().info.maxTokens,
temperature: 0, temperature: 0,
}, },
}) })

View File

@@ -24,7 +24,7 @@ export class OllamaHandler implements ApiHandler {
] ]
const stream = await this.client.chat.completions.create({ const stream = await this.client.chat.completions.create({
model: this.options.ollamaModelId ?? "", model: this.getModel().id,
messages: openAiMessages, messages: openAiMessages,
temperature: 0, temperature: 0,
stream: true, stream: true,

View File

@@ -30,7 +30,7 @@ export class OpenAiNativeHandler implements ApiHandler {
const stream = await this.client.chat.completions.create({ const stream = await this.client.chat.completions.create({
model: this.getModel().id, model: this.getModel().id,
max_completion_tokens: this.getModel().info.maxTokens, // max_completion_tokens: this.getModel().info.maxTokens,
temperature: 0, temperature: 0,
messages: openAiMessages, messages: openAiMessages,
stream: true, stream: true,

View File

@@ -2,13 +2,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios" import axios from "axios"
import OpenAI from "openai" import OpenAI from "openai"
import { ApiHandler } from "../" import { ApiHandler } from "../"
import { import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
ApiHandlerOptions,
ModelInfo,
openRouterDefaultModelId,
OpenRouterModelId,
openRouterModels,
} from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format" import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream" import { ApiStream } from "../transform/stream"
@@ -74,9 +68,18 @@ export class OpenRouterHandler implements ApiHandler {
break break
} }
// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
// (models usually default to max tokens allowed)
let maxTokens: number | undefined
switch (this.getModel().id) {
case "anthropic/claude-3.5-sonnet":
case "anthropic/claude-3.5-sonnet:beta":
maxTokens = 8_192
break
}
const stream = await this.client.chat.completions.create({ const stream = await this.client.chat.completions.create({
model: this.getModel().id, model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens, max_tokens: maxTokens,
temperature: 0, temperature: 0,
messages: openAiMessages, messages: openAiMessages,
stream: true, stream: true,
@@ -129,12 +132,12 @@ export class OpenRouterHandler implements ApiHandler {
} }
} }
getModel(): { id: OpenRouterModelId; info: ModelInfo } { getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.apiModelId const modelId = this.options.openRouterModelId
if (modelId && modelId in openRouterModels) { const modelInfo = this.options.openRouterModelInfo
const id = modelId as OpenRouterModelId if (modelId && modelInfo) {
return { id, info: openRouterModels[id] } return { id: modelId, info: modelInfo }
} }
return { id: openRouterDefaultModelId, info: openRouterModels[openRouterDefaultModelId] } return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
} }
} }

View File

@@ -1,7 +1,7 @@
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode" import * as vscode from "vscode"
import { ClaudeDev } from "../ClaudeDev" import { ClaudeDev } from "../ClaudeDev"
import { ApiProvider } from "../../shared/api" import { ApiProvider, ModelInfo } from "../../shared/api"
import { ExtensionMessage } from "../../shared/ExtensionMessage" import { ExtensionMessage } from "../../shared/ExtensionMessage"
import { WebviewMessage } from "../../shared/WebviewMessage" import { WebviewMessage } from "../../shared/WebviewMessage"
import { findLast } from "../../shared/array" import { findLast } from "../../shared/array"
@@ -52,10 +52,13 @@ type GlobalStateKey =
| "ollamaBaseUrl" | "ollamaBaseUrl"
| "anthropicBaseUrl" | "anthropicBaseUrl"
| "azureApiVersion" | "azureApiVersion"
| "openRouterModelId"
| "openRouterModelInfo"
export const GlobalFileNames = { export const GlobalFileNames = {
apiConversationHistory: "api_conversation_history.json", apiConversationHistory: "api_conversation_history.json",
claudeMessages: "claude_messages.json", claudeMessages: "claude_messages.json",
openRouterModels: "openrouter_models.json",
} }
export class ClaudeDevProvider implements vscode.WebviewViewProvider { export class ClaudeDevProvider implements vscode.WebviewViewProvider {
@@ -322,10 +325,19 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
async (message: WebviewMessage) => { async (message: WebviewMessage) => {
switch (message.type) { switch (message.type) {
case "webviewDidLaunch": case "webviewDidLaunch":
await this.postStateToWebview() this.postStateToWebview()
const theme = await getTheme() this.workspaceTracker?.initializeFilePaths() // don't await
await this.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) }) getTheme().then((theme) =>
this.workspaceTracker?.initializeFilePaths() this.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })
)
this.readOpenRouterModels().then((openRouterModels) => {
if (openRouterModels) {
this.postMessageToWebview({ type: "openRouterModels", openRouterModels })
} else {
// nothing cached, fetch first time
this.refreshOpenRouterModels()
}
})
break break
case "newTask": case "newTask":
// Code that should run in response to the hello message command // Code that should run in response to the hello message command
@@ -360,6 +372,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
geminiApiKey, geminiApiKey,
openAiNativeApiKey, openAiNativeApiKey,
azureApiVersion, azureApiVersion,
openRouterModelId,
openRouterModelInfo,
} = message.apiConfiguration } = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider) await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId) await this.updateGlobalState("apiModelId", apiModelId)
@@ -380,6 +394,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.storeSecret("geminiApiKey", geminiApiKey) await this.storeSecret("geminiApiKey", geminiApiKey)
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey) await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
await this.updateGlobalState("azureApiVersion", azureApiVersion) await this.updateGlobalState("azureApiVersion", azureApiVersion)
await this.updateGlobalState("openRouterModelId", openRouterModelId)
await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
if (this.claudeDev) { if (this.claudeDev) {
this.claudeDev.api = buildApiHandler(message.apiConfiguration) this.claudeDev.api = buildApiHandler(message.apiConfiguration)
} }
@@ -431,8 +447,11 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.resetState() await this.resetState()
break break
case "requestOllamaModels": case "requestOllamaModels":
const models = await this.getOllamaModels(message.text) const ollamaModels = await this.getOllamaModels(message.text)
this.postMessageToWebview({ type: "ollamaModels", models }) this.postMessageToWebview({ type: "ollamaModels", ollamaModels })
break
case "refreshOpenRouterModels":
await this.refreshOpenRouterModels()
break break
case "openImage": case "openImage":
openImage(message.text!) openImage(message.text!)
@@ -518,6 +537,97 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
// await this.postMessageToWebview({ type: "action", action: "settingsButtonTapped" }) // bad ux if user is on welcome // await this.postMessageToWebview({ type: "action", action: "settingsButtonTapped" }) // bad ux if user is on welcome
} }
async readOpenRouterModels(): Promise<Record<string, ModelInfo> | undefined> {
const cacheDir = path.join(this.context.globalStorageUri.fsPath, "cache")
const openRouterModelsFilePath = path.join(cacheDir, GlobalFileNames.openRouterModels)
const fileExists = await fileExistsAtPath(openRouterModelsFilePath)
if (fileExists) {
const fileContents = await fs.readFile(openRouterModelsFilePath, "utf8")
return JSON.parse(fileContents)
}
return undefined
}
async refreshOpenRouterModels() {
const cacheDir = path.join(this.context.globalStorageUri.fsPath, "cache")
const openRouterModelsFilePath = path.join(cacheDir, GlobalFileNames.openRouterModels)
let models: Record<string, ModelInfo> = {}
try {
const response = await axios.get("https://openrouter.ai/api/v1/models")
/*
{
"id": "anthropic/claude-3.5-sonnet",
"name": "Anthropic: Claude 3.5 Sonnet",
"created": 1718841600,
"description": "Claude 3.5 Sonnet delivers better-than-Opus capabilities, faster-than-Sonnet speeds, at the same Sonnet prices. Sonnet is particularly good at:\n\n- Coding: Autonomously writes, edits, and runs code with reasoning and troubleshooting\n- Data science: Augments human data science expertise; navigates unstructured data while using multiple tools for insights\n- Visual processing: excelling at interpreting charts, graphs, and images, accurately transcribing text to derive insights beyond just the text alone\n- Agentic tasks: exceptional tool use, making it great at agentic tasks (i.e. complex, multi-step problem solving tasks that require engaging with other systems)\n\n#multimodal",
"context_length": 200000,
"architecture": {
"modality": "text+image-\u003Etext",
"tokenizer": "Claude",
"instruct_type": null
},
"pricing": {
"prompt": "0.000003",
"completion": "0.000015",
"image": "0.0048",
"request": "0"
},
"top_provider": {
"context_length": 200000,
"max_completion_tokens": 8192,
"is_moderated": true
},
"per_request_limits": null
},
*/
if (response.data) {
const rawModels = response.data
for (const rawModel of rawModels) {
const modelInfo: ModelInfo = {
maxTokens: rawModel.top_provider?.max_completion_tokens || 2048,
contextWindow: rawModel.context_length || 128_000,
supportsImages: rawModel.architecture?.modality?.includes("image") ?? false,
supportsPromptCache: false,
inputPrice: parseFloat(rawModel.pricing?.prompt || 0) * 1_000_000,
outputPrice: parseFloat(rawModel.pricing?.completion || 0) * 1_000_000,
description: rawModel.description,
}
switch (rawModel.id) {
case "anthropic/claude-3.5-sonnet":
case "anthropic/claude-3.5-sonnet:beta":
modelInfo.supportsPromptCache = true
modelInfo.cacheWritesPrice = 3.75
modelInfo.cacheReadsPrice = 0.3
break
case "anthropic/claude-3-opus":
case "anthropic/claude-3-opus:beta":
modelInfo.supportsPromptCache = true
modelInfo.cacheWritesPrice = 18.75
modelInfo.cacheReadsPrice = 1.5
break
case "anthropic/claude-3-haiku":
case "anthropic/claude-3-haiku:beta":
modelInfo.supportsPromptCache = true
modelInfo.cacheWritesPrice = 0.3
modelInfo.cacheReadsPrice = 0.03
break
}
models[rawModel.id] = modelInfo
}
} else {
console.error("Invalid response from OpenRouter API")
}
await fs.writeFile(openRouterModelsFilePath, JSON.stringify(models))
} catch (error) {
console.error("Error fetching OpenRouter models:", error)
}
await this.postMessageToWebview({ type: "openRouterModels", openRouterModels: models })
}
// Task history // Task history
async getTaskWithId(id: string): Promise<{ async getTaskWithId(id: string): Promise<{
@@ -722,6 +832,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
geminiApiKey, geminiApiKey,
openAiNativeApiKey, openAiNativeApiKey,
azureApiVersion, azureApiVersion,
openRouterModelId,
openRouterModelInfo,
lastShownAnnouncementId, lastShownAnnouncementId,
customInstructions, customInstructions,
alwaysAllowReadOnly, alwaysAllowReadOnly,
@@ -746,6 +858,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getSecret("geminiApiKey") as Promise<string | undefined>, this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>, this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getGlobalState("azureApiVersion") as Promise<string | undefined>, this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>, this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
this.getGlobalState("customInstructions") as Promise<string | undefined>, this.getGlobalState("customInstructions") as Promise<string | undefined>,
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>, this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -787,6 +901,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
geminiApiKey, geminiApiKey,
openAiNativeApiKey, openAiNativeApiKey,
azureApiVersion, azureApiVersion,
openRouterModelId,
openRouterModelInfo,
}, },
lastShownAnnouncementId, lastShownAnnouncementId,
customInstructions, customInstructions,

View File

@@ -1,6 +1,6 @@
// type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonTapped' or 'settingsButtonTapped' or 'hello' // type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonTapped' or 'settingsButtonTapped' or 'hello'
import { ApiConfiguration } from "./api" import { ApiConfiguration, ModelInfo } from "./api"
import { HistoryItem } from "./HistoryItem" import { HistoryItem } from "./HistoryItem"
// webview will hold state // webview will hold state
@@ -14,14 +14,16 @@ export interface ExtensionMessage {
| "workspaceUpdated" | "workspaceUpdated"
| "invoke" | "invoke"
| "partialMessage" | "partialMessage"
| "openRouterModels"
text?: string text?: string
action?: "chatButtonTapped" | "settingsButtonTapped" | "historyButtonTapped" | "didBecomeVisible" action?: "chatButtonTapped" | "settingsButtonTapped" | "historyButtonTapped" | "didBecomeVisible"
invoke?: "sendMessage" | "primaryButtonClick" | "secondaryButtonClick" invoke?: "sendMessage" | "primaryButtonClick" | "secondaryButtonClick"
state?: ExtensionState state?: ExtensionState
images?: string[] images?: string[]
models?: string[] ollamaModels?: string[]
filePaths?: string[] filePaths?: string[]
partialMessage?: ClaudeMessage partialMessage?: ClaudeMessage
openRouterModels?: Record<string, ModelInfo>
} }
export interface ExtensionState { export interface ExtensionState {

View File

@@ -21,6 +21,7 @@ export interface WebviewMessage {
| "openFile" | "openFile"
| "openMention" | "openMention"
| "cancelTask" | "cancelTask"
| "refreshOpenRouterModels"
text?: string text?: string
askResponse?: ClaudeAskResponse askResponse?: ClaudeAskResponse
apiConfiguration?: ApiConfiguration apiConfiguration?: ApiConfiguration

View File

@@ -13,6 +13,8 @@ export interface ApiHandlerOptions {
apiKey?: string // anthropic apiKey?: string // anthropic
anthropicBaseUrl?: string anthropicBaseUrl?: string
openRouterApiKey?: string openRouterApiKey?: string
openRouterModelId?: string
openRouterModelInfo?: ModelInfo
awsAccessKey?: string awsAccessKey?: string
awsSecretKey?: string awsSecretKey?: string
awsSessionToken?: string awsSessionToken?: string
@@ -44,6 +46,7 @@ export interface ModelInfo {
outputPrice: number outputPrice: number
cacheWritesPrice?: number cacheWritesPrice?: number
cacheReadsPrice?: number cacheReadsPrice?: number
description?: string
} }
// Anthropic // Anthropic
@@ -116,9 +119,19 @@ export const bedrockModels = {
// OpenRouter // OpenRouter
// https://openrouter.ai/models?order=newest&supported_parameters=tools // https://openrouter.ai/models?order=newest&supported_parameters=tools
export type OpenRouterModelId = keyof typeof openRouterModels type OpenRouterModelId = keyof typeof openRouterModels
export const openRouterDefaultModelId: OpenRouterModelId = "anthropic/claude-3.5-sonnet:beta" export const openRouterDefaultModelId = "anthropic/claude-3.5-sonnet:beta" // will always exist in openRouterModels
export const openRouterModels = { export const openRouterDefaultModelInfo: ModelInfo = {
maxTokens: 8192,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
inputPrice: 3.0,
outputPrice: 15.0,
cacheWritesPrice: 3.75,
cacheReadsPrice: 0.3,
}
const openRouterModels = {
"anthropic/claude-3.5-sonnet:beta": { "anthropic/claude-3.5-sonnet:beta": {
maxTokens: 8192, maxTokens: 8192,
contextWindow: 200_000, contextWindow: 200_000,

View File

@@ -23,7 +23,7 @@ import {
openAiNativeDefaultModelId, openAiNativeDefaultModelId,
openAiNativeModels, openAiNativeModels,
openRouterDefaultModelId, openRouterDefaultModelId,
openRouterModels, openRouterDefaultModelInfo,
vertexDefaultModelId, vertexDefaultModelId,
vertexModels, vertexModels,
} from "../../../../src/shared/api" } from "../../../../src/shared/api"
@@ -66,8 +66,8 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
const handleMessage = useCallback((event: MessageEvent) => { const handleMessage = useCallback((event: MessageEvent) => {
const message: ExtensionMessage = event.data const message: ExtensionMessage = event.data
if (message.type === "ollamaModels" && message.models) { if (message.type === "ollamaModels" && message.ollamaModels) {
setOllamaModels(message.models) setOllamaModels(message.ollamaModels)
} }
}, []) }, [])
useEvent("message", handleMessage) useEvent("message", handleMessage)
@@ -529,23 +529,25 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
</p> </p>
)} )}
{selectedProvider !== "openai" && selectedProvider !== "ollama" && showModelOptions && ( {selectedProvider !== "openrouter" &&
<> selectedProvider !== "openai" &&
<div className="dropdown-container"> selectedProvider !== "ollama" &&
<label htmlFor="model-id"> showModelOptions && (
<span style={{ fontWeight: 500 }}>Model</span> <>
</label> <div className="dropdown-container">
{selectedProvider === "anthropic" && createDropdown(anthropicModels)} <label htmlFor="model-id">
{selectedProvider === "openrouter" && createDropdown(openRouterModels)} <span style={{ fontWeight: 500 }}>Model</span>
{selectedProvider === "bedrock" && createDropdown(bedrockModels)} </label>
{selectedProvider === "vertex" && createDropdown(vertexModels)} {selectedProvider === "anthropic" && createDropdown(anthropicModels)}
{selectedProvider === "gemini" && createDropdown(geminiModels)} {selectedProvider === "bedrock" && createDropdown(bedrockModels)}
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)} {selectedProvider === "vertex" && createDropdown(vertexModels)}
</div> {selectedProvider === "gemini" && createDropdown(geminiModels)}
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
</div>
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} /> <ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
</> </>
)} )}
</div> </div>
) )
} }
@@ -563,7 +565,7 @@ export const formatPrice = (price: number) => {
}).format(price) }).format(price)
} }
const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => { export const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
const isGemini = Object.keys(geminiModels).includes(selectedModelId) const isGemini = Object.keys(geminiModels).includes(selectedModelId)
const isO1 = selectedModelId && selectedModelId.includes("o1") const isO1 = selectedModelId && selectedModelId.includes("o1")
return ( return (
@@ -690,8 +692,6 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
switch (provider) { switch (provider) {
case "anthropic": case "anthropic":
return getProviderData(anthropicModels, anthropicDefaultModelId) return getProviderData(anthropicModels, anthropicDefaultModelId)
case "openrouter":
return getProviderData(openRouterModels, openRouterDefaultModelId)
case "bedrock": case "bedrock":
return getProviderData(bedrockModels, bedrockDefaultModelId) return getProviderData(bedrockModels, bedrockDefaultModelId)
case "vertex": case "vertex":
@@ -700,6 +700,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
return getProviderData(geminiModels, geminiDefaultModelId) return getProviderData(geminiModels, geminiDefaultModelId)
case "openai-native": case "openai-native":
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId) return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
case "openrouter":
return {
selectedProvider: provider,
selectedModelId: apiConfiguration?.openRouterModelId ?? openRouterDefaultModelId,
selectedModelInfo: apiConfiguration?.openRouterModelInfo ?? openRouterDefaultModelInfo,
}
case "openai": case "openai":
return { return {
selectedProvider: provider, selectedProvider: provider,

View File

@@ -0,0 +1,59 @@
import { VSCodeDropdown, VSCodeOption } from "@vscode/webview-ui-toolkit/react"
import React, { useMemo } from "react"
import { useExtensionState } from "../../context/ExtensionStateContext"
import { ModelInfoView, normalizeApiConfiguration } from "./ApiOptions"
interface OpenRouterModelPickerProps {}
const OpenRouterModelPicker: React.FC<OpenRouterModelPickerProps> = () => {
const { apiConfiguration, setApiConfiguration, openRouterModels } = useExtensionState()
const handleModelChange = (event: any) => {
const newModelId = event.target.value
// get info
setApiConfiguration({ ...apiConfiguration, openRouterModelId: newModelId })
}
const { selectedModelId, selectedModelInfo } = useMemo(() => {
return normalizeApiConfiguration(apiConfiguration)
}, [apiConfiguration])
return (
<div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
<div className="dropdown-container">
<label htmlFor="model-id">
<span style={{ fontWeight: 500 }}>Model</span>
</label>
<VSCodeDropdown
id="model-id"
value={selectedModelId}
onChange={handleModelChange}
style={{ width: "100%" }}>
<VSCodeOption value="">Select a model...</VSCodeOption>
{Object.keys(openRouterModels).map((modelId) => (
<VSCodeOption
key={modelId}
value={modelId}
style={{
whiteSpace: "normal",
wordWrap: "break-word",
maxWidth: "100%",
}}>
{modelId}
</VSCodeOption>
))}
</VSCodeDropdown>
</div>
{selectedModelInfo.description && (
<p style={{ fontSize: "12px", marginTop: "2px", color: "var(--vscode-descriptionForeground)" }}>
{selectedModelInfo.description}
</p>
)}
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
</div>
)
}
export default OpenRouterModelPicker

View File

@@ -1,7 +1,7 @@
import React, { createContext, useCallback, useContext, useEffect, useState } from "react" import React, { createContext, useCallback, useContext, useEffect, useState } from "react"
import { useEvent } from "react-use" import { useEvent } from "react-use"
import { ExtensionMessage, ExtensionState } from "../../../src/shared/ExtensionMessage" import { ExtensionMessage, ExtensionState } from "../../../src/shared/ExtensionMessage"
import { ApiConfiguration } from "../../../src/shared/api" import { ApiConfiguration, ModelInfo } from "../../../src/shared/api"
import { vscode } from "../utils/vscode" import { vscode } from "../utils/vscode"
import { convertTextMateToHljs } from "../utils/textMateToHljs" import { convertTextMateToHljs } from "../utils/textMateToHljs"
import { findLastIndex } from "../../../src/shared/array" import { findLastIndex } from "../../../src/shared/array"
@@ -10,6 +10,7 @@ interface ExtensionStateContextType extends ExtensionState {
didHydrateState: boolean didHydrateState: boolean
showWelcome: boolean showWelcome: boolean
theme: any theme: any
openRouterModels: Record<string, ModelInfo>
filePaths: string[] filePaths: string[]
setApiConfiguration: (config: ApiConfiguration) => void setApiConfiguration: (config: ApiConfiguration) => void
setCustomInstructions: (value?: string) => void setCustomInstructions: (value?: string) => void
@@ -30,6 +31,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
const [showWelcome, setShowWelcome] = useState(false) const [showWelcome, setShowWelcome] = useState(false)
const [theme, setTheme] = useState<any>(undefined) const [theme, setTheme] = useState<any>(undefined)
const [filePaths, setFilePaths] = useState<string[]>([]) const [filePaths, setFilePaths] = useState<string[]>([])
const [openRouterModels, setOpenRouterModels] = useState<Record<string, ModelInfo>>({})
const handleMessage = useCallback((event: MessageEvent) => { const handleMessage = useCallback((event: MessageEvent) => {
const message: ExtensionMessage = event.data const message: ExtensionMessage = event.data
@@ -75,6 +77,11 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
} }
return prevState return prevState
}) })
break
}
case "openRouterModels": {
setOpenRouterModels(message.openRouterModels ?? {})
break
} }
} }
}, []) }, [])
@@ -90,6 +97,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
didHydrateState, didHydrateState,
showWelcome, showWelcome,
theme, theme,
openRouterModels,
filePaths, filePaths,
setApiConfiguration: (value) => setState((prevState) => ({ ...prevState, apiConfiguration: value })), setApiConfiguration: (value) => setState((prevState) => ({ ...prevState, apiConfiguration: value })),
setCustomInstructions: (value) => setState((prevState) => ({ ...prevState, customInstructions: value })), setCustomInstructions: (value) => setState((prevState) => ({ ...prevState, customInstructions: value })),

View File

@@ -14,8 +14,8 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
} }
break break
case "openrouter": case "openrouter":
if (!apiConfiguration.openRouterApiKey) { if (!apiConfiguration.openRouterApiKey || !apiConfiguration.openRouterModelId) {
return "You must provide a valid API key or choose a different provider." return "You must provide a valid API key and model ID or choose a different provider."
} }
break break
case "vertex": case "vertex":