mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 12:21:13 -05:00
Add OpenRouter custom model scheme
This commit is contained in:
@@ -25,7 +25,7 @@ export class GeminiHandler implements ApiHandler {
|
|||||||
const result = await model.generateContentStream({
|
const result = await model.generateContentStream({
|
||||||
contents: messages.map(convertAnthropicMessageToGemini),
|
contents: messages.map(convertAnthropicMessageToGemini),
|
||||||
generationConfig: {
|
generationConfig: {
|
||||||
maxOutputTokens: this.getModel().info.maxTokens,
|
// maxOutputTokens: this.getModel().info.maxTokens,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ export class OllamaHandler implements ApiHandler {
|
|||||||
]
|
]
|
||||||
|
|
||||||
const stream = await this.client.chat.completions.create({
|
const stream = await this.client.chat.completions.create({
|
||||||
model: this.options.ollamaModelId ?? "",
|
model: this.getModel().id,
|
||||||
messages: openAiMessages,
|
messages: openAiMessages,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: true,
|
stream: true,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ export class OpenAiNativeHandler implements ApiHandler {
|
|||||||
|
|
||||||
const stream = await this.client.chat.completions.create({
|
const stream = await this.client.chat.completions.create({
|
||||||
model: this.getModel().id,
|
model: this.getModel().id,
|
||||||
max_completion_tokens: this.getModel().info.maxTokens,
|
// max_completion_tokens: this.getModel().info.maxTokens,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
messages: openAiMessages,
|
messages: openAiMessages,
|
||||||
stream: true,
|
stream: true,
|
||||||
|
|||||||
@@ -2,13 +2,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
|
|||||||
import axios from "axios"
|
import axios from "axios"
|
||||||
import OpenAI from "openai"
|
import OpenAI from "openai"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler } from "../"
|
||||||
import {
|
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
|
||||||
ApiHandlerOptions,
|
|
||||||
ModelInfo,
|
|
||||||
openRouterDefaultModelId,
|
|
||||||
OpenRouterModelId,
|
|
||||||
openRouterModels,
|
|
||||||
} from "../../shared/api"
|
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
@@ -74,9 +68,18 @@ export class OpenRouterHandler implements ApiHandler {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
|
||||||
|
// (models usually default to max tokens allowed)
|
||||||
|
let maxTokens: number | undefined
|
||||||
|
switch (this.getModel().id) {
|
||||||
|
case "anthropic/claude-3.5-sonnet":
|
||||||
|
case "anthropic/claude-3.5-sonnet:beta":
|
||||||
|
maxTokens = 8_192
|
||||||
|
break
|
||||||
|
}
|
||||||
const stream = await this.client.chat.completions.create({
|
const stream = await this.client.chat.completions.create({
|
||||||
model: this.getModel().id,
|
model: this.getModel().id,
|
||||||
max_tokens: this.getModel().info.maxTokens,
|
max_tokens: maxTokens,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
messages: openAiMessages,
|
messages: openAiMessages,
|
||||||
stream: true,
|
stream: true,
|
||||||
@@ -129,12 +132,12 @@ export class OpenRouterHandler implements ApiHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
getModel(): { id: OpenRouterModelId; info: ModelInfo } {
|
getModel(): { id: string; info: ModelInfo } {
|
||||||
const modelId = this.options.apiModelId
|
const modelId = this.options.openRouterModelId
|
||||||
if (modelId && modelId in openRouterModels) {
|
const modelInfo = this.options.openRouterModelInfo
|
||||||
const id = modelId as OpenRouterModelId
|
if (modelId && modelInfo) {
|
||||||
return { id, info: openRouterModels[id] }
|
return { id: modelId, info: modelInfo }
|
||||||
}
|
}
|
||||||
return { id: openRouterDefaultModelId, info: openRouterModels[openRouterDefaultModelId] }
|
return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import * as vscode from "vscode"
|
import * as vscode from "vscode"
|
||||||
import { ClaudeDev } from "../ClaudeDev"
|
import { ClaudeDev } from "../ClaudeDev"
|
||||||
import { ApiProvider } from "../../shared/api"
|
import { ApiProvider, ModelInfo } from "../../shared/api"
|
||||||
import { ExtensionMessage } from "../../shared/ExtensionMessage"
|
import { ExtensionMessage } from "../../shared/ExtensionMessage"
|
||||||
import { WebviewMessage } from "../../shared/WebviewMessage"
|
import { WebviewMessage } from "../../shared/WebviewMessage"
|
||||||
import { findLast } from "../../shared/array"
|
import { findLast } from "../../shared/array"
|
||||||
@@ -52,10 +52,13 @@ type GlobalStateKey =
|
|||||||
| "ollamaBaseUrl"
|
| "ollamaBaseUrl"
|
||||||
| "anthropicBaseUrl"
|
| "anthropicBaseUrl"
|
||||||
| "azureApiVersion"
|
| "azureApiVersion"
|
||||||
|
| "openRouterModelId"
|
||||||
|
| "openRouterModelInfo"
|
||||||
|
|
||||||
export const GlobalFileNames = {
|
export const GlobalFileNames = {
|
||||||
apiConversationHistory: "api_conversation_history.json",
|
apiConversationHistory: "api_conversation_history.json",
|
||||||
claudeMessages: "claude_messages.json",
|
claudeMessages: "claude_messages.json",
|
||||||
|
openRouterModels: "openrouter_models.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
||||||
@@ -322,10 +325,19 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
async (message: WebviewMessage) => {
|
async (message: WebviewMessage) => {
|
||||||
switch (message.type) {
|
switch (message.type) {
|
||||||
case "webviewDidLaunch":
|
case "webviewDidLaunch":
|
||||||
await this.postStateToWebview()
|
this.postStateToWebview()
|
||||||
const theme = await getTheme()
|
this.workspaceTracker?.initializeFilePaths() // don't await
|
||||||
await this.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })
|
getTheme().then((theme) =>
|
||||||
this.workspaceTracker?.initializeFilePaths()
|
this.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })
|
||||||
|
)
|
||||||
|
this.readOpenRouterModels().then((openRouterModels) => {
|
||||||
|
if (openRouterModels) {
|
||||||
|
this.postMessageToWebview({ type: "openRouterModels", openRouterModels })
|
||||||
|
} else {
|
||||||
|
// nothing cached, fetch first time
|
||||||
|
this.refreshOpenRouterModels()
|
||||||
|
}
|
||||||
|
})
|
||||||
break
|
break
|
||||||
case "newTask":
|
case "newTask":
|
||||||
// Code that should run in response to the hello message command
|
// Code that should run in response to the hello message command
|
||||||
@@ -360,6 +372,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
geminiApiKey,
|
geminiApiKey,
|
||||||
openAiNativeApiKey,
|
openAiNativeApiKey,
|
||||||
azureApiVersion,
|
azureApiVersion,
|
||||||
|
openRouterModelId,
|
||||||
|
openRouterModelInfo,
|
||||||
} = message.apiConfiguration
|
} = message.apiConfiguration
|
||||||
await this.updateGlobalState("apiProvider", apiProvider)
|
await this.updateGlobalState("apiProvider", apiProvider)
|
||||||
await this.updateGlobalState("apiModelId", apiModelId)
|
await this.updateGlobalState("apiModelId", apiModelId)
|
||||||
@@ -380,6 +394,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
await this.storeSecret("geminiApiKey", geminiApiKey)
|
await this.storeSecret("geminiApiKey", geminiApiKey)
|
||||||
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
|
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
|
||||||
await this.updateGlobalState("azureApiVersion", azureApiVersion)
|
await this.updateGlobalState("azureApiVersion", azureApiVersion)
|
||||||
|
await this.updateGlobalState("openRouterModelId", openRouterModelId)
|
||||||
|
await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
|
||||||
if (this.claudeDev) {
|
if (this.claudeDev) {
|
||||||
this.claudeDev.api = buildApiHandler(message.apiConfiguration)
|
this.claudeDev.api = buildApiHandler(message.apiConfiguration)
|
||||||
}
|
}
|
||||||
@@ -431,8 +447,11 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
await this.resetState()
|
await this.resetState()
|
||||||
break
|
break
|
||||||
case "requestOllamaModels":
|
case "requestOllamaModels":
|
||||||
const models = await this.getOllamaModels(message.text)
|
const ollamaModels = await this.getOllamaModels(message.text)
|
||||||
this.postMessageToWebview({ type: "ollamaModels", models })
|
this.postMessageToWebview({ type: "ollamaModels", ollamaModels })
|
||||||
|
break
|
||||||
|
case "refreshOpenRouterModels":
|
||||||
|
await this.refreshOpenRouterModels()
|
||||||
break
|
break
|
||||||
case "openImage":
|
case "openImage":
|
||||||
openImage(message.text!)
|
openImage(message.text!)
|
||||||
@@ -518,6 +537,97 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
// await this.postMessageToWebview({ type: "action", action: "settingsButtonTapped" }) // bad ux if user is on welcome
|
// await this.postMessageToWebview({ type: "action", action: "settingsButtonTapped" }) // bad ux if user is on welcome
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async readOpenRouterModels(): Promise<Record<string, ModelInfo> | undefined> {
|
||||||
|
const cacheDir = path.join(this.context.globalStorageUri.fsPath, "cache")
|
||||||
|
const openRouterModelsFilePath = path.join(cacheDir, GlobalFileNames.openRouterModels)
|
||||||
|
const fileExists = await fileExistsAtPath(openRouterModelsFilePath)
|
||||||
|
if (fileExists) {
|
||||||
|
const fileContents = await fs.readFile(openRouterModelsFilePath, "utf8")
|
||||||
|
return JSON.parse(fileContents)
|
||||||
|
}
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
async refreshOpenRouterModels() {
|
||||||
|
const cacheDir = path.join(this.context.globalStorageUri.fsPath, "cache")
|
||||||
|
const openRouterModelsFilePath = path.join(cacheDir, GlobalFileNames.openRouterModels)
|
||||||
|
|
||||||
|
let models: Record<string, ModelInfo> = {}
|
||||||
|
try {
|
||||||
|
const response = await axios.get("https://openrouter.ai/api/v1/models")
|
||||||
|
/*
|
||||||
|
{
|
||||||
|
"id": "anthropic/claude-3.5-sonnet",
|
||||||
|
"name": "Anthropic: Claude 3.5 Sonnet",
|
||||||
|
"created": 1718841600,
|
||||||
|
"description": "Claude 3.5 Sonnet delivers better-than-Opus capabilities, faster-than-Sonnet speeds, at the same Sonnet prices. Sonnet is particularly good at:\n\n- Coding: Autonomously writes, edits, and runs code with reasoning and troubleshooting\n- Data science: Augments human data science expertise; navigates unstructured data while using multiple tools for insights\n- Visual processing: excelling at interpreting charts, graphs, and images, accurately transcribing text to derive insights beyond just the text alone\n- Agentic tasks: exceptional tool use, making it great at agentic tasks (i.e. complex, multi-step problem solving tasks that require engaging with other systems)\n\n#multimodal",
|
||||||
|
"context_length": 200000,
|
||||||
|
"architecture": {
|
||||||
|
"modality": "text+image-\u003Etext",
|
||||||
|
"tokenizer": "Claude",
|
||||||
|
"instruct_type": null
|
||||||
|
},
|
||||||
|
"pricing": {
|
||||||
|
"prompt": "0.000003",
|
||||||
|
"completion": "0.000015",
|
||||||
|
"image": "0.0048",
|
||||||
|
"request": "0"
|
||||||
|
},
|
||||||
|
"top_provider": {
|
||||||
|
"context_length": 200000,
|
||||||
|
"max_completion_tokens": 8192,
|
||||||
|
"is_moderated": true
|
||||||
|
},
|
||||||
|
"per_request_limits": null
|
||||||
|
},
|
||||||
|
*/
|
||||||
|
if (response.data) {
|
||||||
|
const rawModels = response.data
|
||||||
|
for (const rawModel of rawModels) {
|
||||||
|
const modelInfo: ModelInfo = {
|
||||||
|
maxTokens: rawModel.top_provider?.max_completion_tokens || 2048,
|
||||||
|
contextWindow: rawModel.context_length || 128_000,
|
||||||
|
supportsImages: rawModel.architecture?.modality?.includes("image") ?? false,
|
||||||
|
supportsPromptCache: false,
|
||||||
|
inputPrice: parseFloat(rawModel.pricing?.prompt || 0) * 1_000_000,
|
||||||
|
outputPrice: parseFloat(rawModel.pricing?.completion || 0) * 1_000_000,
|
||||||
|
description: rawModel.description,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (rawModel.id) {
|
||||||
|
case "anthropic/claude-3.5-sonnet":
|
||||||
|
case "anthropic/claude-3.5-sonnet:beta":
|
||||||
|
modelInfo.supportsPromptCache = true
|
||||||
|
modelInfo.cacheWritesPrice = 3.75
|
||||||
|
modelInfo.cacheReadsPrice = 0.3
|
||||||
|
break
|
||||||
|
case "anthropic/claude-3-opus":
|
||||||
|
case "anthropic/claude-3-opus:beta":
|
||||||
|
modelInfo.supportsPromptCache = true
|
||||||
|
modelInfo.cacheWritesPrice = 18.75
|
||||||
|
modelInfo.cacheReadsPrice = 1.5
|
||||||
|
break
|
||||||
|
case "anthropic/claude-3-haiku":
|
||||||
|
case "anthropic/claude-3-haiku:beta":
|
||||||
|
modelInfo.supportsPromptCache = true
|
||||||
|
modelInfo.cacheWritesPrice = 0.3
|
||||||
|
modelInfo.cacheReadsPrice = 0.03
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
models[rawModel.id] = modelInfo
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error("Invalid response from OpenRouter API")
|
||||||
|
}
|
||||||
|
await fs.writeFile(openRouterModelsFilePath, JSON.stringify(models))
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error fetching OpenRouter models:", error)
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.postMessageToWebview({ type: "openRouterModels", openRouterModels: models })
|
||||||
|
}
|
||||||
|
|
||||||
// Task history
|
// Task history
|
||||||
|
|
||||||
async getTaskWithId(id: string): Promise<{
|
async getTaskWithId(id: string): Promise<{
|
||||||
@@ -722,6 +832,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
geminiApiKey,
|
geminiApiKey,
|
||||||
openAiNativeApiKey,
|
openAiNativeApiKey,
|
||||||
azureApiVersion,
|
azureApiVersion,
|
||||||
|
openRouterModelId,
|
||||||
|
openRouterModelInfo,
|
||||||
lastShownAnnouncementId,
|
lastShownAnnouncementId,
|
||||||
customInstructions,
|
customInstructions,
|
||||||
alwaysAllowReadOnly,
|
alwaysAllowReadOnly,
|
||||||
@@ -746,6 +858,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
this.getSecret("geminiApiKey") as Promise<string | undefined>,
|
this.getSecret("geminiApiKey") as Promise<string | undefined>,
|
||||||
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
|
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
|
||||||
this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
|
this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
|
||||||
|
this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
|
||||||
|
this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
|
||||||
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
|
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
|
||||||
this.getGlobalState("customInstructions") as Promise<string | undefined>,
|
this.getGlobalState("customInstructions") as Promise<string | undefined>,
|
||||||
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
|
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
|
||||||
@@ -787,6 +901,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
|
|||||||
geminiApiKey,
|
geminiApiKey,
|
||||||
openAiNativeApiKey,
|
openAiNativeApiKey,
|
||||||
azureApiVersion,
|
azureApiVersion,
|
||||||
|
openRouterModelId,
|
||||||
|
openRouterModelInfo,
|
||||||
},
|
},
|
||||||
lastShownAnnouncementId,
|
lastShownAnnouncementId,
|
||||||
customInstructions,
|
customInstructions,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonTapped' or 'settingsButtonTapped' or 'hello'
|
// type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonTapped' or 'settingsButtonTapped' or 'hello'
|
||||||
|
|
||||||
import { ApiConfiguration } from "./api"
|
import { ApiConfiguration, ModelInfo } from "./api"
|
||||||
import { HistoryItem } from "./HistoryItem"
|
import { HistoryItem } from "./HistoryItem"
|
||||||
|
|
||||||
// webview will hold state
|
// webview will hold state
|
||||||
@@ -14,14 +14,16 @@ export interface ExtensionMessage {
|
|||||||
| "workspaceUpdated"
|
| "workspaceUpdated"
|
||||||
| "invoke"
|
| "invoke"
|
||||||
| "partialMessage"
|
| "partialMessage"
|
||||||
|
| "openRouterModels"
|
||||||
text?: string
|
text?: string
|
||||||
action?: "chatButtonTapped" | "settingsButtonTapped" | "historyButtonTapped" | "didBecomeVisible"
|
action?: "chatButtonTapped" | "settingsButtonTapped" | "historyButtonTapped" | "didBecomeVisible"
|
||||||
invoke?: "sendMessage" | "primaryButtonClick" | "secondaryButtonClick"
|
invoke?: "sendMessage" | "primaryButtonClick" | "secondaryButtonClick"
|
||||||
state?: ExtensionState
|
state?: ExtensionState
|
||||||
images?: string[]
|
images?: string[]
|
||||||
models?: string[]
|
ollamaModels?: string[]
|
||||||
filePaths?: string[]
|
filePaths?: string[]
|
||||||
partialMessage?: ClaudeMessage
|
partialMessage?: ClaudeMessage
|
||||||
|
openRouterModels?: Record<string, ModelInfo>
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ExtensionState {
|
export interface ExtensionState {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ export interface WebviewMessage {
|
|||||||
| "openFile"
|
| "openFile"
|
||||||
| "openMention"
|
| "openMention"
|
||||||
| "cancelTask"
|
| "cancelTask"
|
||||||
|
| "refreshOpenRouterModels"
|
||||||
text?: string
|
text?: string
|
||||||
askResponse?: ClaudeAskResponse
|
askResponse?: ClaudeAskResponse
|
||||||
apiConfiguration?: ApiConfiguration
|
apiConfiguration?: ApiConfiguration
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ export interface ApiHandlerOptions {
|
|||||||
apiKey?: string // anthropic
|
apiKey?: string // anthropic
|
||||||
anthropicBaseUrl?: string
|
anthropicBaseUrl?: string
|
||||||
openRouterApiKey?: string
|
openRouterApiKey?: string
|
||||||
|
openRouterModelId?: string
|
||||||
|
openRouterModelInfo?: ModelInfo
|
||||||
awsAccessKey?: string
|
awsAccessKey?: string
|
||||||
awsSecretKey?: string
|
awsSecretKey?: string
|
||||||
awsSessionToken?: string
|
awsSessionToken?: string
|
||||||
@@ -44,6 +46,7 @@ export interface ModelInfo {
|
|||||||
outputPrice: number
|
outputPrice: number
|
||||||
cacheWritesPrice?: number
|
cacheWritesPrice?: number
|
||||||
cacheReadsPrice?: number
|
cacheReadsPrice?: number
|
||||||
|
description?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Anthropic
|
// Anthropic
|
||||||
@@ -116,9 +119,19 @@ export const bedrockModels = {
|
|||||||
|
|
||||||
// OpenRouter
|
// OpenRouter
|
||||||
// https://openrouter.ai/models?order=newest&supported_parameters=tools
|
// https://openrouter.ai/models?order=newest&supported_parameters=tools
|
||||||
export type OpenRouterModelId = keyof typeof openRouterModels
|
type OpenRouterModelId = keyof typeof openRouterModels
|
||||||
export const openRouterDefaultModelId: OpenRouterModelId = "anthropic/claude-3.5-sonnet:beta"
|
export const openRouterDefaultModelId = "anthropic/claude-3.5-sonnet:beta" // will always exist in openRouterModels
|
||||||
export const openRouterModels = {
|
export const openRouterDefaultModelInfo: ModelInfo = {
|
||||||
|
maxTokens: 8192,
|
||||||
|
contextWindow: 200_000,
|
||||||
|
supportsImages: true,
|
||||||
|
supportsPromptCache: true,
|
||||||
|
inputPrice: 3.0,
|
||||||
|
outputPrice: 15.0,
|
||||||
|
cacheWritesPrice: 3.75,
|
||||||
|
cacheReadsPrice: 0.3,
|
||||||
|
}
|
||||||
|
const openRouterModels = {
|
||||||
"anthropic/claude-3.5-sonnet:beta": {
|
"anthropic/claude-3.5-sonnet:beta": {
|
||||||
maxTokens: 8192,
|
maxTokens: 8192,
|
||||||
contextWindow: 200_000,
|
contextWindow: 200_000,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import {
|
|||||||
openAiNativeDefaultModelId,
|
openAiNativeDefaultModelId,
|
||||||
openAiNativeModels,
|
openAiNativeModels,
|
||||||
openRouterDefaultModelId,
|
openRouterDefaultModelId,
|
||||||
openRouterModels,
|
openRouterDefaultModelInfo,
|
||||||
vertexDefaultModelId,
|
vertexDefaultModelId,
|
||||||
vertexModels,
|
vertexModels,
|
||||||
} from "../../../../src/shared/api"
|
} from "../../../../src/shared/api"
|
||||||
@@ -66,8 +66,8 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
|
|||||||
|
|
||||||
const handleMessage = useCallback((event: MessageEvent) => {
|
const handleMessage = useCallback((event: MessageEvent) => {
|
||||||
const message: ExtensionMessage = event.data
|
const message: ExtensionMessage = event.data
|
||||||
if (message.type === "ollamaModels" && message.models) {
|
if (message.type === "ollamaModels" && message.ollamaModels) {
|
||||||
setOllamaModels(message.models)
|
setOllamaModels(message.ollamaModels)
|
||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
useEvent("message", handleMessage)
|
useEvent("message", handleMessage)
|
||||||
@@ -529,14 +529,16 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
|
|||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{selectedProvider !== "openai" && selectedProvider !== "ollama" && showModelOptions && (
|
{selectedProvider !== "openrouter" &&
|
||||||
|
selectedProvider !== "openai" &&
|
||||||
|
selectedProvider !== "ollama" &&
|
||||||
|
showModelOptions && (
|
||||||
<>
|
<>
|
||||||
<div className="dropdown-container">
|
<div className="dropdown-container">
|
||||||
<label htmlFor="model-id">
|
<label htmlFor="model-id">
|
||||||
<span style={{ fontWeight: 500 }}>Model</span>
|
<span style={{ fontWeight: 500 }}>Model</span>
|
||||||
</label>
|
</label>
|
||||||
{selectedProvider === "anthropic" && createDropdown(anthropicModels)}
|
{selectedProvider === "anthropic" && createDropdown(anthropicModels)}
|
||||||
{selectedProvider === "openrouter" && createDropdown(openRouterModels)}
|
|
||||||
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
|
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
|
||||||
{selectedProvider === "vertex" && createDropdown(vertexModels)}
|
{selectedProvider === "vertex" && createDropdown(vertexModels)}
|
||||||
{selectedProvider === "gemini" && createDropdown(geminiModels)}
|
{selectedProvider === "gemini" && createDropdown(geminiModels)}
|
||||||
@@ -563,7 +565,7 @@ export const formatPrice = (price: number) => {
|
|||||||
}).format(price)
|
}).format(price)
|
||||||
}
|
}
|
||||||
|
|
||||||
const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
|
export const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
|
||||||
const isGemini = Object.keys(geminiModels).includes(selectedModelId)
|
const isGemini = Object.keys(geminiModels).includes(selectedModelId)
|
||||||
const isO1 = selectedModelId && selectedModelId.includes("o1")
|
const isO1 = selectedModelId && selectedModelId.includes("o1")
|
||||||
return (
|
return (
|
||||||
@@ -690,8 +692,6 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
|
|||||||
switch (provider) {
|
switch (provider) {
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
return getProviderData(anthropicModels, anthropicDefaultModelId)
|
return getProviderData(anthropicModels, anthropicDefaultModelId)
|
||||||
case "openrouter":
|
|
||||||
return getProviderData(openRouterModels, openRouterDefaultModelId)
|
|
||||||
case "bedrock":
|
case "bedrock":
|
||||||
return getProviderData(bedrockModels, bedrockDefaultModelId)
|
return getProviderData(bedrockModels, bedrockDefaultModelId)
|
||||||
case "vertex":
|
case "vertex":
|
||||||
@@ -700,6 +700,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
|
|||||||
return getProviderData(geminiModels, geminiDefaultModelId)
|
return getProviderData(geminiModels, geminiDefaultModelId)
|
||||||
case "openai-native":
|
case "openai-native":
|
||||||
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
|
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
|
||||||
|
case "openrouter":
|
||||||
|
return {
|
||||||
|
selectedProvider: provider,
|
||||||
|
selectedModelId: apiConfiguration?.openRouterModelId ?? openRouterDefaultModelId,
|
||||||
|
selectedModelInfo: apiConfiguration?.openRouterModelInfo ?? openRouterDefaultModelInfo,
|
||||||
|
}
|
||||||
case "openai":
|
case "openai":
|
||||||
return {
|
return {
|
||||||
selectedProvider: provider,
|
selectedProvider: provider,
|
||||||
|
|||||||
59
webview-ui/src/components/settings/OpenRouterModelPicker.tsx
Normal file
59
webview-ui/src/components/settings/OpenRouterModelPicker.tsx
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import { VSCodeDropdown, VSCodeOption } from "@vscode/webview-ui-toolkit/react"
|
||||||
|
import React, { useMemo } from "react"
|
||||||
|
import { useExtensionState } from "../../context/ExtensionStateContext"
|
||||||
|
import { ModelInfoView, normalizeApiConfiguration } from "./ApiOptions"
|
||||||
|
|
||||||
|
interface OpenRouterModelPickerProps {}
|
||||||
|
|
||||||
|
const OpenRouterModelPicker: React.FC<OpenRouterModelPickerProps> = () => {
|
||||||
|
const { apiConfiguration, setApiConfiguration, openRouterModels } = useExtensionState()
|
||||||
|
|
||||||
|
const handleModelChange = (event: any) => {
|
||||||
|
const newModelId = event.target.value
|
||||||
|
// get info
|
||||||
|
setApiConfiguration({ ...apiConfiguration, openRouterModelId: newModelId })
|
||||||
|
}
|
||||||
|
|
||||||
|
const { selectedModelId, selectedModelInfo } = useMemo(() => {
|
||||||
|
return normalizeApiConfiguration(apiConfiguration)
|
||||||
|
}, [apiConfiguration])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
|
||||||
|
<div className="dropdown-container">
|
||||||
|
<label htmlFor="model-id">
|
||||||
|
<span style={{ fontWeight: 500 }}>Model</span>
|
||||||
|
</label>
|
||||||
|
<VSCodeDropdown
|
||||||
|
id="model-id"
|
||||||
|
value={selectedModelId}
|
||||||
|
onChange={handleModelChange}
|
||||||
|
style={{ width: "100%" }}>
|
||||||
|
<VSCodeOption value="">Select a model...</VSCodeOption>
|
||||||
|
{Object.keys(openRouterModels).map((modelId) => (
|
||||||
|
<VSCodeOption
|
||||||
|
key={modelId}
|
||||||
|
value={modelId}
|
||||||
|
style={{
|
||||||
|
whiteSpace: "normal",
|
||||||
|
wordWrap: "break-word",
|
||||||
|
maxWidth: "100%",
|
||||||
|
}}>
|
||||||
|
{modelId}
|
||||||
|
</VSCodeOption>
|
||||||
|
))}
|
||||||
|
</VSCodeDropdown>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{selectedModelInfo.description && (
|
||||||
|
<p style={{ fontSize: "12px", marginTop: "2px", color: "var(--vscode-descriptionForeground)" }}>
|
||||||
|
{selectedModelInfo.description}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default OpenRouterModelPicker
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import React, { createContext, useCallback, useContext, useEffect, useState } from "react"
|
import React, { createContext, useCallback, useContext, useEffect, useState } from "react"
|
||||||
import { useEvent } from "react-use"
|
import { useEvent } from "react-use"
|
||||||
import { ExtensionMessage, ExtensionState } from "../../../src/shared/ExtensionMessage"
|
import { ExtensionMessage, ExtensionState } from "../../../src/shared/ExtensionMessage"
|
||||||
import { ApiConfiguration } from "../../../src/shared/api"
|
import { ApiConfiguration, ModelInfo } from "../../../src/shared/api"
|
||||||
import { vscode } from "../utils/vscode"
|
import { vscode } from "../utils/vscode"
|
||||||
import { convertTextMateToHljs } from "../utils/textMateToHljs"
|
import { convertTextMateToHljs } from "../utils/textMateToHljs"
|
||||||
import { findLastIndex } from "../../../src/shared/array"
|
import { findLastIndex } from "../../../src/shared/array"
|
||||||
@@ -10,6 +10,7 @@ interface ExtensionStateContextType extends ExtensionState {
|
|||||||
didHydrateState: boolean
|
didHydrateState: boolean
|
||||||
showWelcome: boolean
|
showWelcome: boolean
|
||||||
theme: any
|
theme: any
|
||||||
|
openRouterModels: Record<string, ModelInfo>
|
||||||
filePaths: string[]
|
filePaths: string[]
|
||||||
setApiConfiguration: (config: ApiConfiguration) => void
|
setApiConfiguration: (config: ApiConfiguration) => void
|
||||||
setCustomInstructions: (value?: string) => void
|
setCustomInstructions: (value?: string) => void
|
||||||
@@ -30,6 +31,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
|
|||||||
const [showWelcome, setShowWelcome] = useState(false)
|
const [showWelcome, setShowWelcome] = useState(false)
|
||||||
const [theme, setTheme] = useState<any>(undefined)
|
const [theme, setTheme] = useState<any>(undefined)
|
||||||
const [filePaths, setFilePaths] = useState<string[]>([])
|
const [filePaths, setFilePaths] = useState<string[]>([])
|
||||||
|
const [openRouterModels, setOpenRouterModels] = useState<Record<string, ModelInfo>>({})
|
||||||
|
|
||||||
const handleMessage = useCallback((event: MessageEvent) => {
|
const handleMessage = useCallback((event: MessageEvent) => {
|
||||||
const message: ExtensionMessage = event.data
|
const message: ExtensionMessage = event.data
|
||||||
@@ -75,6 +77,11 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
|
|||||||
}
|
}
|
||||||
return prevState
|
return prevState
|
||||||
})
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case "openRouterModels": {
|
||||||
|
setOpenRouterModels(message.openRouterModels ?? {})
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
@@ -90,6 +97,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
|
|||||||
didHydrateState,
|
didHydrateState,
|
||||||
showWelcome,
|
showWelcome,
|
||||||
theme,
|
theme,
|
||||||
|
openRouterModels,
|
||||||
filePaths,
|
filePaths,
|
||||||
setApiConfiguration: (value) => setState((prevState) => ({ ...prevState, apiConfiguration: value })),
|
setApiConfiguration: (value) => setState((prevState) => ({ ...prevState, apiConfiguration: value })),
|
||||||
setCustomInstructions: (value) => setState((prevState) => ({ ...prevState, customInstructions: value })),
|
setCustomInstructions: (value) => setState((prevState) => ({ ...prevState, customInstructions: value })),
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
|
|||||||
}
|
}
|
||||||
break
|
break
|
||||||
case "openrouter":
|
case "openrouter":
|
||||||
if (!apiConfiguration.openRouterApiKey) {
|
if (!apiConfiguration.openRouterApiKey || !apiConfiguration.openRouterModelId) {
|
||||||
return "You must provide a valid API key or choose a different provider."
|
return "You must provide a valid API key and model ID or choose a different provider."
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case "vertex":
|
case "vertex":
|
||||||
|
|||||||
Reference in New Issue
Block a user