Add OpenRouter custom model scheme

This commit is contained in:
Saoud Rizwan
2024-10-03 19:46:09 -04:00
parent d5b3bd7788
commit 7cb0c524e5
12 changed files with 262 additions and 54 deletions

View File

@@ -25,7 +25,7 @@ export class GeminiHandler implements ApiHandler {
const result = await model.generateContentStream({
contents: messages.map(convertAnthropicMessageToGemini),
generationConfig: {
maxOutputTokens: this.getModel().info.maxTokens,
// maxOutputTokens: this.getModel().info.maxTokens,
temperature: 0,
},
})

View File

@@ -24,7 +24,7 @@ export class OllamaHandler implements ApiHandler {
]
const stream = await this.client.chat.completions.create({
model: this.options.ollamaModelId ?? "",
model: this.getModel().id,
messages: openAiMessages,
temperature: 0,
stream: true,

View File

@@ -30,7 +30,7 @@ export class OpenAiNativeHandler implements ApiHandler {
const stream = await this.client.chat.completions.create({
model: this.getModel().id,
max_completion_tokens: this.getModel().info.maxTokens,
// max_completion_tokens: this.getModel().info.maxTokens,
temperature: 0,
messages: openAiMessages,
stream: true,

View File

@@ -2,13 +2,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"
import OpenAI from "openai"
import { ApiHandler } from "../"
import {
ApiHandlerOptions,
ModelInfo,
openRouterDefaultModelId,
OpenRouterModelId,
openRouterModels,
} from "../../shared/api"
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
@@ -74,9 +68,18 @@ export class OpenRouterHandler implements ApiHandler {
break
}
// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
// (models usually default to max tokens allowed)
let maxTokens: number | undefined
switch (this.getModel().id) {
case "anthropic/claude-3.5-sonnet":
case "anthropic/claude-3.5-sonnet:beta":
maxTokens = 8_192
break
}
const stream = await this.client.chat.completions.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
max_tokens: maxTokens,
temperature: 0,
messages: openAiMessages,
stream: true,
@@ -129,12 +132,12 @@ export class OpenRouterHandler implements ApiHandler {
}
}
getModel(): { id: OpenRouterModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in openRouterModels) {
const id = modelId as OpenRouterModelId
return { id, info: openRouterModels[id] }
getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.openRouterModelId
const modelInfo = this.options.openRouterModelInfo
if (modelId && modelInfo) {
return { id: modelId, info: modelInfo }
}
return { id: openRouterDefaultModelId, info: openRouterModels[openRouterDefaultModelId] }
return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
}
}

View File

@@ -1,7 +1,7 @@
import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode"
import { ClaudeDev } from "../ClaudeDev"
import { ApiProvider } from "../../shared/api"
import { ApiProvider, ModelInfo } from "../../shared/api"
import { ExtensionMessage } from "../../shared/ExtensionMessage"
import { WebviewMessage } from "../../shared/WebviewMessage"
import { findLast } from "../../shared/array"
@@ -52,10 +52,13 @@ type GlobalStateKey =
| "ollamaBaseUrl"
| "anthropicBaseUrl"
| "azureApiVersion"
| "openRouterModelId"
| "openRouterModelInfo"
export const GlobalFileNames = {
apiConversationHistory: "api_conversation_history.json",
claudeMessages: "claude_messages.json",
openRouterModels: "openrouter_models.json",
}
export class ClaudeDevProvider implements vscode.WebviewViewProvider {
@@ -322,10 +325,19 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
async (message: WebviewMessage) => {
switch (message.type) {
case "webviewDidLaunch":
await this.postStateToWebview()
const theme = await getTheme()
await this.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })
this.workspaceTracker?.initializeFilePaths()
this.postStateToWebview()
this.workspaceTracker?.initializeFilePaths() // don't await
getTheme().then((theme) =>
this.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })
)
this.readOpenRouterModels().then((openRouterModels) => {
if (openRouterModels) {
this.postMessageToWebview({ type: "openRouterModels", openRouterModels })
} else {
// nothing cached, fetch first time
this.refreshOpenRouterModels()
}
})
break
case "newTask":
// Code that should run in response to the hello message command
@@ -360,6 +372,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
geminiApiKey,
openAiNativeApiKey,
azureApiVersion,
openRouterModelId,
openRouterModelInfo,
} = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId)
@@ -380,6 +394,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.storeSecret("geminiApiKey", geminiApiKey)
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
await this.updateGlobalState("azureApiVersion", azureApiVersion)
await this.updateGlobalState("openRouterModelId", openRouterModelId)
await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
if (this.claudeDev) {
this.claudeDev.api = buildApiHandler(message.apiConfiguration)
}
@@ -431,8 +447,11 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.resetState()
break
case "requestOllamaModels":
const models = await this.getOllamaModels(message.text)
this.postMessageToWebview({ type: "ollamaModels", models })
const ollamaModels = await this.getOllamaModels(message.text)
this.postMessageToWebview({ type: "ollamaModels", ollamaModels })
break
case "refreshOpenRouterModels":
await this.refreshOpenRouterModels()
break
case "openImage":
openImage(message.text!)
@@ -518,6 +537,97 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
// await this.postMessageToWebview({ type: "action", action: "settingsButtonTapped" }) // bad ux if user is on welcome
}
/**
 * Loads the OpenRouter model list previously cached on disk by refreshOpenRouterModels().
 *
 * The cache lives at `<globalStorage>/cache/<GlobalFileNames.openRouterModels>`.
 *
 * @returns the cached map of model id -> ModelInfo, or `undefined` when there is no usable
 * cache (file missing, unreadable, not valid JSON, not a plain object, or empty). Callers
 * treat `undefined` as "nothing cached, fetch from the API".
 */
async readOpenRouterModels(): Promise<Record<string, ModelInfo> | undefined> {
	const cacheDir = path.join(this.context.globalStorageUri.fsPath, "cache")
	const openRouterModelsFilePath = path.join(cacheDir, GlobalFileNames.openRouterModels)
	const fileExists = await fileExistsAtPath(openRouterModelsFilePath)
	if (!fileExists) {
		return undefined
	}
	try {
		const fileContents = await fs.readFile(openRouterModelsFilePath, "utf8")
		const parsed: unknown = JSON.parse(fileContents)
		// Only accept a non-empty plain object. An empty or malformed cache is reported as a
		// miss so the caller falls back to fetching instead of showing an empty model list.
		if (
			parsed !== null &&
			typeof parsed === "object" &&
			!Array.isArray(parsed) &&
			Object.keys(parsed).length > 0
		) {
			return parsed as Record<string, ModelInfo>
		}
	} catch (error) {
		// A truncated/corrupted cache file must not surface as an unhandled rejection in the
		// webview launch path; log it and report a cache miss.
		console.error("Error reading cached OpenRouter models:", error)
	}
	return undefined
}
/**
 * Fetches the current model list from the OpenRouter API, converts each entry to a
 * ModelInfo, caches the result on disk and posts it to the webview.
 *
 * Failures (network error, unexpected payload, disk error) are logged and never thrown.
 * An "openRouterModels" message is always posted, carrying whatever was parsed (possibly
 * an empty object).
 */
async refreshOpenRouterModels() {
	const cacheDir = path.join(this.context.globalStorageUri.fsPath, "cache")
	const openRouterModelsFilePath = path.join(cacheDir, GlobalFileNames.openRouterModels)

	// The API reports prices as decimal strings (see sample below, e.g. "0.000003"), while
	// ModelInfo prices are scaled by 1_000_000. Missing or unparsable values become 0.
	const toScaledPrice = (raw: unknown): number => {
		const parsed = parseFloat(String(raw || "0"))
		return Number.isFinite(parsed) ? parsed * 1_000_000 : 0
	}

	let models: Record<string, ModelInfo> = {}
	try {
		const response = await axios.get("https://openrouter.ai/api/v1/models")
		/*
		The response body wraps the list in a `data` property: { "data": [ <model>, ... ] }
		Sample <model> entry:
		{
			"id": "anthropic/claude-3.5-sonnet",
			"name": "Anthropic: Claude 3.5 Sonnet",
			"created": 1718841600,
			"description": "Claude 3.5 Sonnet delivers better-than-Opus capabilities, ...",
			"context_length": 200000,
			"architecture": {
				"modality": "text+image->text",
				"tokenizer": "Claude",
				"instruct_type": null
			},
			"pricing": {
				"prompt": "0.000003",
				"completion": "0.000015",
				"image": "0.0048",
				"request": "0"
			},
			"top_provider": {
				"context_length": 200000,
				"max_completion_tokens": 8192,
				"is_moderated": true
			},
			"per_request_limits": null
		},
		*/
		// Accept the documented `{ data: [...] }` envelope, and also a bare array for safety.
		const payload = response.data
		const rawModels: any[] | undefined = Array.isArray(payload?.data)
			? payload.data
			: Array.isArray(payload)
			? payload
			: undefined

		if (rawModels) {
			for (const rawModel of rawModels) {
				// Entries without a usable id cannot be keyed or selected; skip them.
				if (typeof rawModel?.id !== "string" || rawModel.id.length === 0) {
					continue
				}
				const modelInfo: ModelInfo = {
					maxTokens: rawModel.top_provider?.max_completion_tokens || 2048,
					contextWindow: rawModel.context_length || 128_000,
					supportsImages: rawModel.architecture?.modality?.includes("image") ?? false,
					supportsPromptCache: false,
					inputPrice: toScaledPrice(rawModel.pricing?.prompt),
					outputPrice: toScaledPrice(rawModel.pricing?.completion),
					description: rawModel.description,
				}

				// Prompt-cache support and cache prices are hard-coded here for the
				// Anthropic models below rather than read from the API entry.
				switch (rawModel.id) {
					case "anthropic/claude-3.5-sonnet":
					case "anthropic/claude-3.5-sonnet:beta":
						modelInfo.supportsPromptCache = true
						modelInfo.cacheWritesPrice = 3.75
						modelInfo.cacheReadsPrice = 0.3
						break
					case "anthropic/claude-3-opus":
					case "anthropic/claude-3-opus:beta":
						modelInfo.supportsPromptCache = true
						modelInfo.cacheWritesPrice = 18.75
						modelInfo.cacheReadsPrice = 1.5
						break
					case "anthropic/claude-3-haiku":
					case "anthropic/claude-3-haiku:beta":
						modelInfo.supportsPromptCache = true
						modelInfo.cacheWritesPrice = 0.3
						modelInfo.cacheReadsPrice = 0.03
						break
				}

				models[rawModel.id] = modelInfo
			}
		} else {
			console.error("Invalid response from OpenRouter API")
		}

		// Only persist a non-empty result: caching `{}` would look like a valid cache on the
		// next launch and suppress the automatic first-time fetch.
		if (Object.keys(models).length > 0) {
			// The cache directory is not guaranteed to exist yet (first run).
			await fs.mkdir(cacheDir, { recursive: true })
			await fs.writeFile(openRouterModelsFilePath, JSON.stringify(models))
		}
	} catch (error) {
		console.error("Error fetching OpenRouter models:", error)
	}

	await this.postMessageToWebview({ type: "openRouterModels", openRouterModels: models })
}
// Task history
async getTaskWithId(id: string): Promise<{
@@ -722,6 +832,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
geminiApiKey,
openAiNativeApiKey,
azureApiVersion,
openRouterModelId,
openRouterModelInfo,
lastShownAnnouncementId,
customInstructions,
alwaysAllowReadOnly,
@@ -746,6 +858,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
this.getGlobalState("customInstructions") as Promise<string | undefined>,
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -787,6 +901,8 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
geminiApiKey,
openAiNativeApiKey,
azureApiVersion,
openRouterModelId,
openRouterModelInfo,
},
lastShownAnnouncementId,
customInstructions,

View File

@@ -1,6 +1,6 @@
// type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonTapped' or 'settingsButtonTapped' or 'hello'
import { ApiConfiguration } from "./api"
import { ApiConfiguration, ModelInfo } from "./api"
import { HistoryItem } from "./HistoryItem"
// webview will hold state
@@ -14,14 +14,16 @@ export interface ExtensionMessage {
| "workspaceUpdated"
| "invoke"
| "partialMessage"
| "openRouterModels"
text?: string
action?: "chatButtonTapped" | "settingsButtonTapped" | "historyButtonTapped" | "didBecomeVisible"
invoke?: "sendMessage" | "primaryButtonClick" | "secondaryButtonClick"
state?: ExtensionState
images?: string[]
models?: string[]
ollamaModels?: string[]
filePaths?: string[]
partialMessage?: ClaudeMessage
openRouterModels?: Record<string, ModelInfo>
}
export interface ExtensionState {

View File

@@ -21,6 +21,7 @@ export interface WebviewMessage {
| "openFile"
| "openMention"
| "cancelTask"
| "refreshOpenRouterModels"
text?: string
askResponse?: ClaudeAskResponse
apiConfiguration?: ApiConfiguration

View File

@@ -13,6 +13,8 @@ export interface ApiHandlerOptions {
apiKey?: string // anthropic
anthropicBaseUrl?: string
openRouterApiKey?: string
openRouterModelId?: string
openRouterModelInfo?: ModelInfo
awsAccessKey?: string
awsSecretKey?: string
awsSessionToken?: string
@@ -44,6 +46,7 @@ export interface ModelInfo {
outputPrice: number
cacheWritesPrice?: number
cacheReadsPrice?: number
description?: string
}
// Anthropic
@@ -116,9 +119,19 @@ export const bedrockModels = {
// OpenRouter
// https://openrouter.ai/models?order=newest&supported_parameters=tools
export type OpenRouterModelId = keyof typeof openRouterModels
export const openRouterDefaultModelId: OpenRouterModelId = "anthropic/claude-3.5-sonnet:beta"
export const openRouterModels = {
type OpenRouterModelId = keyof typeof openRouterModels
export const openRouterDefaultModelId = "anthropic/claude-3.5-sonnet:beta" // will always exist in openRouterModels
/**
 * Fallback ModelInfo for the OpenRouter provider, paired with `openRouterDefaultModelId`.
 * It is used when no OpenRouter model id/info has been stored in the API configuration
 * (see OpenRouterHandler.getModel and normalizeApiConfiguration).
 *
 * NOTE(review): prices appear to be expressed per million tokens — refreshOpenRouterModels
 * multiplies the API's price strings by 1_000_000 to produce comparable values. Confirm the
 * currency/unit against the ModelInfo consumers.
 */
export const openRouterDefaultModelInfo: ModelInfo = {
maxTokens: 8192,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
inputPrice: 3.0,
outputPrice: 15.0,
// Same cache prices that refreshOpenRouterModels assigns to "anthropic/claude-3.5-sonnet(:beta)".
cacheWritesPrice: 3.75,
cacheReadsPrice: 0.3,
}
const openRouterModels = {
"anthropic/claude-3.5-sonnet:beta": {
maxTokens: 8192,
contextWindow: 200_000,

View File

@@ -23,7 +23,7 @@ import {
openAiNativeDefaultModelId,
openAiNativeModels,
openRouterDefaultModelId,
openRouterModels,
openRouterDefaultModelInfo,
vertexDefaultModelId,
vertexModels,
} from "../../../../src/shared/api"
@@ -66,8 +66,8 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
const handleMessage = useCallback((event: MessageEvent) => {
const message: ExtensionMessage = event.data
if (message.type === "ollamaModels" && message.models) {
setOllamaModels(message.models)
if (message.type === "ollamaModels" && message.ollamaModels) {
setOllamaModels(message.ollamaModels)
}
}, [])
useEvent("message", handleMessage)
@@ -529,23 +529,25 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
</p>
)}
{selectedProvider !== "openai" && selectedProvider !== "ollama" && showModelOptions && (
<>
<div className="dropdown-container">
<label htmlFor="model-id">
<span style={{ fontWeight: 500 }}>Model</span>
</label>
{selectedProvider === "anthropic" && createDropdown(anthropicModels)}
{selectedProvider === "openrouter" && createDropdown(openRouterModels)}
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
{selectedProvider === "vertex" && createDropdown(vertexModels)}
{selectedProvider === "gemini" && createDropdown(geminiModels)}
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
</div>
{selectedProvider !== "openrouter" &&
selectedProvider !== "openai" &&
selectedProvider !== "ollama" &&
showModelOptions && (
<>
<div className="dropdown-container">
<label htmlFor="model-id">
<span style={{ fontWeight: 500 }}>Model</span>
</label>
{selectedProvider === "anthropic" && createDropdown(anthropicModels)}
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
{selectedProvider === "vertex" && createDropdown(vertexModels)}
{selectedProvider === "gemini" && createDropdown(geminiModels)}
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
</div>
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
</>
)}
<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
</>
)}
</div>
)
}
@@ -563,7 +565,7 @@ export const formatPrice = (price: number) => {
}).format(price)
}
const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
export const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
const isGemini = Object.keys(geminiModels).includes(selectedModelId)
const isO1 = selectedModelId && selectedModelId.includes("o1")
return (
@@ -690,8 +692,6 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
switch (provider) {
case "anthropic":
return getProviderData(anthropicModels, anthropicDefaultModelId)
case "openrouter":
return getProviderData(openRouterModels, openRouterDefaultModelId)
case "bedrock":
return getProviderData(bedrockModels, bedrockDefaultModelId)
case "vertex":
@@ -700,6 +700,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
return getProviderData(geminiModels, geminiDefaultModelId)
case "openai-native":
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
case "openrouter":
return {
selectedProvider: provider,
selectedModelId: apiConfiguration?.openRouterModelId ?? openRouterDefaultModelId,
selectedModelInfo: apiConfiguration?.openRouterModelInfo ?? openRouterDefaultModelInfo,
}
case "openai":
return {
selectedProvider: provider,

View File

@@ -0,0 +1,59 @@
import { VSCodeDropdown, VSCodeOption } from "@vscode/webview-ui-toolkit/react"
import React, { useMemo } from "react"
import { useExtensionState } from "../../context/ExtensionStateContext"
import { ModelInfoView, normalizeApiConfiguration } from "./ApiOptions"
interface OpenRouterModelPickerProps {}
/**
 * Dropdown for choosing an OpenRouter model from the list supplied by the extension
 * (`openRouterModels` in the extension state context), followed by the selected model's
 * description and its ModelInfoView.
 *
 * Selecting a model stores BOTH `openRouterModelId` and the matching `openRouterModelInfo`
 * in the API configuration. OpenRouterHandler.getModel only honours the selection when both
 * are present; otherwise it falls back to the default model.
 */
const OpenRouterModelPicker: React.FC<OpenRouterModelPickerProps> = () => {
	const { apiConfiguration, setApiConfiguration, openRouterModels } = useExtensionState()

	// `any` because the toolkit dropdown's change event is not a plain DOM ChangeEvent.
	const handleModelChange = (event: any) => {
		const newModelId = event.target.value
		setApiConfiguration({
			...apiConfiguration,
			openRouterModelId: newModelId,
			// Keep the info in sync with the id. This is undefined for the placeholder
			// option or an id that is not in the list, in which case
			// normalizeApiConfiguration falls back to the default info.
			openRouterModelInfo: openRouterModels[newModelId],
		})
	}

	const { selectedModelId, selectedModelInfo } = useMemo(() => {
		return normalizeApiConfiguration(apiConfiguration)
	}, [apiConfiguration])

	return (
		<div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
			<div className="dropdown-container">
				<label htmlFor="model-id">
					<span style={{ fontWeight: 500 }}>Model</span>
				</label>
				<VSCodeDropdown
					id="model-id"
					value={selectedModelId}
					onChange={handleModelChange}
					style={{ width: "100%" }}>
					<VSCodeOption value="">Select a model...</VSCodeOption>
					{Object.keys(openRouterModels).map((modelId) => (
						<VSCodeOption
							key={modelId}
							value={modelId}
							style={{
								whiteSpace: "normal",
								wordWrap: "break-word",
								maxWidth: "100%",
							}}>
							{modelId}
						</VSCodeOption>
					))}
				</VSCodeDropdown>
			</div>

			{selectedModelInfo.description && (
				<p style={{ fontSize: "12px", marginTop: "2px", color: "var(--vscode-descriptionForeground)" }}>
					{selectedModelInfo.description}
				</p>
			)}

			<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
		</div>
	)
}
export default OpenRouterModelPicker

View File

@@ -1,7 +1,7 @@
import React, { createContext, useCallback, useContext, useEffect, useState } from "react"
import { useEvent } from "react-use"
import { ExtensionMessage, ExtensionState } from "../../../src/shared/ExtensionMessage"
import { ApiConfiguration } from "../../../src/shared/api"
import { ApiConfiguration, ModelInfo } from "../../../src/shared/api"
import { vscode } from "../utils/vscode"
import { convertTextMateToHljs } from "../utils/textMateToHljs"
import { findLastIndex } from "../../../src/shared/array"
@@ -10,6 +10,7 @@ interface ExtensionStateContextType extends ExtensionState {
didHydrateState: boolean
showWelcome: boolean
theme: any
openRouterModels: Record<string, ModelInfo>
filePaths: string[]
setApiConfiguration: (config: ApiConfiguration) => void
setCustomInstructions: (value?: string) => void
@@ -30,6 +31,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
const [showWelcome, setShowWelcome] = useState(false)
const [theme, setTheme] = useState<any>(undefined)
const [filePaths, setFilePaths] = useState<string[]>([])
const [openRouterModels, setOpenRouterModels] = useState<Record<string, ModelInfo>>({})
const handleMessage = useCallback((event: MessageEvent) => {
const message: ExtensionMessage = event.data
@@ -75,6 +77,11 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
}
return prevState
})
break
}
case "openRouterModels": {
setOpenRouterModels(message.openRouterModels ?? {})
break
}
}
}, [])
@@ -90,6 +97,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
didHydrateState,
showWelcome,
theme,
openRouterModels,
filePaths,
setApiConfiguration: (value) => setState((prevState) => ({ ...prevState, apiConfiguration: value })),
setCustomInstructions: (value) => setState((prevState) => ({ ...prevState, customInstructions: value })),

View File

@@ -14,8 +14,8 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
}
break
case "openrouter":
if (!apiConfiguration.openRouterApiKey) {
return "You must provide a valid API key or choose a different provider."
if (!apiConfiguration.openRouterApiKey || !apiConfiguration.openRouterModelId) {
return "You must provide a valid API key and model ID or choose a different provider."
}
break
case "vertex":