diff --git a/package-lock.json b/package-lock.json
index 369f7b6..c1a2cd8 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,17 +1,18 @@
 {
   "name": "claude-dev",
-  "version": "1.5.34",
+  "version": "1.6.4",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "claude-dev",
-      "version": "1.5.34",
+      "version": "1.6.4",
       "license": "MIT",
       "dependencies": {
         "@anthropic-ai/bedrock-sdk": "^0.10.2",
         "@anthropic-ai/sdk": "^0.26.0",
         "@anthropic-ai/vertex-sdk": "^0.4.1",
+        "@google/generative-ai": "^0.18.0",
         "@types/clone-deep": "^4.0.4",
         "@types/pdf-parse": "^1.1.4",
         "@vscode/codicons": "^0.0.36",
@@ -2635,6 +2636,15 @@
         "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
       }
     },
+    "node_modules/@google/generative-ai": {
+      "version": "0.18.0",
+      "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz",
+      "integrity": "sha512-AhaIWSpk2tuhYHrBhUqC0xrWWznmYEja1/TRDIb+5kruBU5kUzMlFsXCQNO9PzyTZ4clUJ3CX/Rvy+Xm9x+w3g==",
+      "license": "Apache-2.0",
+      "engines": {
+        "node": ">=18.0.0"
+      }
+    },
     "node_modules/@humanwhocodes/config-array": {
       "version": "0.11.14",
       "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.14.tgz",
diff --git a/package.json b/package.json
index 2572d54..7a5333d 100644
--- a/package.json
+++ b/package.json
@@ -151,6 +151,7 @@
     "@anthropic-ai/bedrock-sdk": "^0.10.2",
     "@anthropic-ai/sdk": "^0.26.0",
     "@anthropic-ai/vertex-sdk": "^0.4.1",
+    "@google/generative-ai": "^0.18.0",
     "@types/clone-deep": "^4.0.4",
     "@types/pdf-parse": "^1.1.4",
     "@vscode/codicons": "^0.0.36",
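For context on the new `@google/generative-ai` dependency: the SDK surface the handler below relies on is `GoogleGenerativeAI.getGenerativeModel()` plus `generateContent()`. A minimal sketch, not part of this diff; the environment variable name is illustrative, and in the extension the key comes from VS Code SecretStorage instead:

```ts
import { GoogleGenerativeAI } from "@google/generative-ai"

async function main() {
    // GEMINI_API_KEY is an illustrative placeholder for wherever the key is stored.
    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!)
    const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash-latest" })

    const result = await model.generateContent("Say hello in one word.")
    console.log(result.response.text())
}

main()
```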
diff --git a/src/api/gemini.ts b/src/api/gemini.ts
new file mode 100644
index 0000000..97f627c
--- /dev/null
+++ b/src/api/gemini.ts
@@ -0,0 +1,57 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai"
+import { ApiHandler, ApiHandlerMessageResponse } from "."
+import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../shared/api"
+import {
+    convertAnthropicMessageToGemini,
+    convertAnthropicToolToGemini,
+    convertGeminiResponseToAnthropic,
+} from "../utils/gemini-format"
+
+export class GeminiHandler implements ApiHandler {
+    private options: ApiHandlerOptions
+    private client: GoogleGenerativeAI
+
+    constructor(options: ApiHandlerOptions) {
+        if (!options.geminiApiKey) {
+            throw new Error("API key is required for Google Gemini")
+        }
+        this.options = options
+        this.client = new GoogleGenerativeAI(options.geminiApiKey)
+    }
+
+    async createMessage(
+        systemPrompt: string,
+        messages: Anthropic.Messages.MessageParam[],
+        tools: Anthropic.Messages.Tool[]
+    ): Promise<ApiHandlerMessageResponse> {
+        const model = this.client.getGenerativeModel({
+            model: this.getModel().id,
+            systemInstruction: systemPrompt,
+            tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }],
+            toolConfig: {
+                functionCallingConfig: {
+                    mode: FunctionCallingMode.AUTO,
+                },
+            },
+        })
+        const result = await model.generateContent({
+            contents: messages.map(convertAnthropicMessageToGemini),
+            generationConfig: {
+                maxOutputTokens: this.getModel().info.maxTokens,
+            },
+        })
+        const message = convertGeminiResponseToAnthropic(result.response)
+
+        return { message }
+    }
+
+    getModel(): { id: GeminiModelId; info: ModelInfo } {
+        const modelId = this.options.apiModelId
+        if (modelId && modelId in geminiModels) {
+            const id = modelId as GeminiModelId
+            return { id, info: geminiModels[id] }
+        }
+        return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
+    }
+}
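A rough usage sketch of the new handler, showing how Anthropic-shaped messages and tools flow through `createMessage`. The key value and the weather tool are made up for illustration; `geminiApiKey` and `apiModelId` are the option names this diff adds:

```ts
import { GeminiHandler } from "./gemini"

async function demo() {
    const handler = new GeminiHandler({
        geminiApiKey: "...", // required: the constructor throws without it
        apiModelId: "gemini-1.5-pro-latest", // optional: getModel() falls back to geminiDefaultModelId
    })

    const { message } = await handler.createMessage(
        "You are a helpful assistant.",
        [{ role: "user", content: "What's the weather in Paris?" }],
        [
            {
                name: "get_weather",
                description: "Get the current weather for a city",
                input_schema: {
                    type: "object",
                    properties: { city: { type: "string" } },
                    required: ["city"],
                },
            },
        ]
    )

    // `message` is an Anthropic.Messages.Message; a Gemini function call comes
    // back as a `tool_use` content block (see convertGeminiResponseToAnthropic below).
    console.log(message.content, message.stop_reason)
}

demo()
```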
diff --git a/src/api/index.ts b/src/api/index.ts
index 9f06ccf..e741bc0 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -6,6 +6,7 @@ import { OpenRouterHandler } from "./openrouter"
 import { VertexHandler } from "./vertex"
 import { OpenAiHandler } from "./openai"
 import { OllamaHandler } from "./ollama"
+import { GeminiHandler } from "./gemini"
 
 export interface ApiHandlerMessageResponse {
     message: Anthropic.Messages.Message
@@ -37,6 +38,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
             return new OpenAiHandler(options)
         case "ollama":
             return new OllamaHandler(options)
+        case "gemini":
+            return new GeminiHandler(options)
         default:
             return new AnthropicHandler(options)
     }
diff --git a/src/providers/ClaudeDevProvider.ts b/src/providers/ClaudeDevProvider.ts
index 7414806..ede6a00 100644
--- a/src/providers/ClaudeDevProvider.ts
+++ b/src/providers/ClaudeDevProvider.ts
@@ -18,7 +18,14 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
 https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
 */
 
-type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" | "openAiApiKey"
+type SecretKey =
+    | "apiKey"
+    | "openRouterApiKey"
+    | "awsAccessKey"
+    | "awsSecretKey"
+    | "awsSessionToken"
+    | "openAiApiKey"
+    | "geminiApiKey"
 type GlobalStateKey =
     | "apiProvider"
     | "apiModelId"
@@ -329,6 +336,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
                         ollamaModelId,
                         ollamaBaseUrl,
                         anthropicBaseUrl,
+                        geminiApiKey,
                     } = message.apiConfiguration
                     await this.updateGlobalState("apiProvider", apiProvider)
                     await this.updateGlobalState("apiModelId", apiModelId)
@@ -346,6 +354,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
                     await this.updateGlobalState("ollamaModelId", ollamaModelId)
                     await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
                     await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
+                    await this.storeSecret("geminiApiKey", geminiApiKey)
                     this.claudeDev?.updateApi(message.apiConfiguration)
                 }
                 await this.postStateToWebview()
@@ -667,6 +676,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
            ollamaModelId,
            ollamaBaseUrl,
            anthropicBaseUrl,
+            geminiApiKey,
            lastShownAnnouncementId,
            customInstructions,
            alwaysAllowReadOnly,
@@ -688,6 +698,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
            this.getGlobalState("ollamaModelId") as Promise<string | undefined>,
            this.getGlobalState("ollamaBaseUrl") as Promise<string | undefined>,
            this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
+            this.getSecret("geminiApiKey") as Promise<string | undefined>,
            this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
            this.getGlobalState("customInstructions") as Promise<string | undefined>,
            this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -726,6 +737,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
                ollamaModelId,
                ollamaBaseUrl,
                anthropicBaseUrl,
+                geminiApiKey,
            },
            lastShownAnnouncementId,
            customInstructions,
@@ -804,6 +816,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
            "awsSecretKey",
            "awsSessionToken",
            "openAiApiKey",
+            "geminiApiKey",
        ]
        for (const key of secretKeys) {
            await this.storeSecret(key, undefined)
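`storeSecret` and `getSecret` are not shown in these hunks; presumably they wrap VS Code's `ExtensionContext.secrets`, so the new `geminiApiKey` lands in encrypted SecretStorage rather than `globalState`. A sketch of that assumed pattern:

```ts
import * as vscode from "vscode"

type SecretKey =
    | "apiKey"
    | "openRouterApiKey"
    | "awsAccessKey"
    | "awsSecretKey"
    | "awsSessionToken"
    | "openAiApiKey"
    | "geminiApiKey"

// Assumed shape of the provider's secret helpers: store when a value is given,
// delete when it is undefined (which is how resetting state clears every key).
async function storeSecret(context: vscode.ExtensionContext, key: SecretKey, value?: string) {
    if (value) {
        await context.secrets.store(key, value)
    } else {
        await context.secrets.delete(key)
    }
}

async function getSecret(context: vscode.ExtensionContext, key: SecretKey): Promise<string | undefined> {
    return context.secrets.get(key)
}
```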
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 0d64a78..db40002 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -1,4 +1,4 @@
-export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama"
+export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama" | "gemini"
 
 export interface ApiHandlerOptions {
     apiModelId?: string
@@ -16,6 +16,7 @@ export interface ApiHandlerOptions {
     openAiModelId?: string
     ollamaModelId?: string
     ollamaBaseUrl?: string
+    geminiApiKey?: string
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -305,3 +306,26 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
     inputPrice: 0,
     outputPrice: 0,
 }
+
+// Gemini
+// https://ai.google.dev/gemini-api/docs/models/gemini
+export type GeminiModelId = keyof typeof geminiModels
+export const geminiDefaultModelId: GeminiModelId = "gemini-1.5-flash-latest"
+export const geminiModels = {
+    "gemini-1.5-flash-latest": {
+        maxTokens: 8192,
+        contextWindow: 1_048_576,
+        supportsImages: true,
+        supportsPromptCache: false,
+        inputPrice: 0,
+        outputPrice: 0,
+    },
+    "gemini-1.5-pro-latest": {
+        maxTokens: 8192,
+        contextWindow: 2_097_152,
+        supportsImages: true,
+        supportsPromptCache: false,
+        inputPrice: 0,
+        outputPrice: 0,
+    },
+} as const satisfies Record<string, ModelInfo>
diff --git a/src/utils/gemini-format.ts b/src/utils/gemini-format.ts
new file mode 100644
index 0000000..dd3207f
--- /dev/null
+++ b/src/utils/gemini-format.ts
@@ -0,0 +1,137 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Content, EnhancedGenerateContentResponse, FunctionDeclaration, Part, SchemaType } from "@google/generative-ai"
+
+export function convertAnthropicContentToGemini(
+    content:
+        | string
+        | Array<
+              | Anthropic.Messages.TextBlockParam
+              | Anthropic.Messages.ImageBlockParam
+              | Anthropic.Messages.ToolUseBlockParam
+              | Anthropic.Messages.ToolResultBlockParam
+          >
+): Part[] {
+    if (typeof content === "string") {
+        return [{ text: content }]
+    }
+    return content.map((block) => {
+        switch (block.type) {
+            case "text":
+                return { text: block.text }
+            case "image":
+                if (block.source.type !== "base64") {
+                    throw new Error("Unsupported image source type")
+                }
+                return {
+                    inlineData: {
+                        data: block.source.data,
+                        mimeType: block.source.media_type,
+                    },
+                }
+            case "tool_use":
+                return {
+                    functionCall: {
+                        name: block.name,
+                        args: block.input,
+                    },
+                } as Part
+            case "tool_result":
+                return {
+                    functionResponse: {
+                        name: block.tool_use_id,
+                        response: {
+                            content: block.content,
+                        },
+                    },
+                }
+            default:
+                throw new Error(`Unsupported content block type: ${(block as any).type}`)
+        }
+    })
+}
+
+export function convertAnthropicMessageToGemini(message: Anthropic.Messages.MessageParam): Content {
+    return {
+        role: message.role === "assistant" ? "model" : message.role,
+        parts: convertAnthropicContentToGemini(message.content),
+    }
+}
+
+export function convertAnthropicToolToGemini(tool: Anthropic.Messages.Tool): FunctionDeclaration {
+    return {
+        name: tool.name,
+        description: tool.description || "",
+        parameters: {
+            type: SchemaType.OBJECT,
+            properties: Object.fromEntries(
+                Object.entries(tool.input_schema.properties || {}).map(([key, value]) => [
+                    key,
+                    {
+                        type: (value as any).type.toUpperCase(),
+                        description: (value as any).description || "",
+                    },
+                ])
+            ),
+            required: (tool.input_schema.required as string[]) || [],
+        },
+    }
+}
+
+export function convertGeminiResponseToAnthropic(
+    response: EnhancedGenerateContentResponse
+): Anthropic.Messages.Message {
+    const content: Anthropic.Messages.ContentBlock[] = []
+
+    // Add the main text response
+    const text = response.text()
+    if (text) {
+        content.push({ type: "text", text })
+    }
+
+    // Add function calls as tool_use blocks
+    const functionCalls = response.functionCalls()
+    if (functionCalls) {
+        functionCalls.forEach((call, index) => {
+            content.push({
+                type: "tool_use",
+                id: `tool_${index}`,
+                name: call.name,
+                input: call.args,
+            })
+        })
+    }
+
+    // Determine stop reason
+    let stop_reason: Anthropic.Messages.Message["stop_reason"] = null
+    const finishReason = response.candidates?.[0]?.finishReason
+    if (finishReason) {
+        switch (finishReason) {
+            case "STOP":
+                stop_reason = "end_turn"
+                break
+            case "MAX_TOKENS":
+                stop_reason = "max_tokens"
+                break
+            case "SAFETY":
+            case "RECITATION":
+            case "OTHER":
+                stop_reason = "stop_sequence"
+                break
+            // Add more cases if needed
+        }
+    }
+
+    return {
+        id: `msg_${Date.now()}`, // Generate a unique ID
+        type: "message",
+        role: "assistant",
+        content,
+        model: "",
+        stop_reason,
+        stop_sequence: null, // Gemini doesn't provide this information
+        usage: {
+            input_tokens: response.usageMetadata?.promptTokenCount ?? 0,
+            output_tokens: response.usageMetadata?.candidatesTokenCount ?? 0,
+        },
+    }
+}
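To make the mapping in `src/utils/gemini-format.ts` concrete, here is roughly what the converters produce for a small input (expected values shown as comments; a sketch rather than a test, and the relative import path is illustrative):

```ts
import { convertAnthropicMessageToGemini, convertAnthropicToolToGemini } from "./gemini-format"

// A user turn carrying a tool result becomes a Gemini Content object whose
// parts hold a functionResponse keyed by the tool_use_id:
const content = convertAnthropicMessageToGemini({
    role: "user",
    content: [
        { type: "text", text: "Here is the result:" },
        { type: "tool_result", tool_use_id: "get_weather", content: "18°C and sunny" },
    ],
})
// => { role: "user", parts: [
//      { text: "Here is the result:" },
//      { functionResponse: { name: "get_weather", response: { content: "18°C and sunny" } } },
//    ] }

// An Anthropic tool definition becomes a FunctionDeclaration with an OBJECT
// schema; property types are upper-cased ("string" -> "STRING"):
const declaration = convertAnthropicToolToGemini({
    name: "get_weather",
    description: "Get the current weather for a city",
    input_schema: {
        type: "object",
        properties: { city: { type: "string", description: "City name" } },
        required: ["city"],
    },
})
console.log(content, declaration)
```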
diff --git a/webview-ui/src/components/ApiOptions.tsx b/webview-ui/src/components/ApiOptions.tsx
index 7621fef..c1f0c8e 100644
--- a/webview-ui/src/components/ApiOptions.tsx
+++ b/webview-ui/src/components/ApiOptions.tsx
@@ -1,13 +1,14 @@
 import {
+    VSCodeCheckbox,
     VSCodeDropdown,
     VSCodeLink,
     VSCodeOption,
     VSCodeRadio,
     VSCodeRadioGroup,
     VSCodeTextField,
-    VSCodeCheckbox,
 } from "@vscode/webview-ui-toolkit/react"
 import { memo, useCallback, useEffect, useMemo, useState } from "react"
+import { useEvent, useInterval } from "react-use"
 import {
     ApiConfiguration,
     ModelInfo,
@@ -15,17 +16,18 @@ import {
     anthropicModels,
     bedrockDefaultModelId,
     bedrockModels,
+    geminiDefaultModelId,
+    geminiModels,
     openAiModelInfoSaneDefaults,
     openRouterDefaultModelId,
     openRouterModels,
     vertexDefaultModelId,
     vertexModels,
 } from "../../../src/shared/api"
-import { useExtensionState } from "../context/ExtensionStateContext"
-import VSCodeButtonLink from "./VSCodeButtonLink"
 import { ExtensionMessage } from "../../../src/shared/ExtensionMessage"
-import { useEvent, useInterval } from "react-use"
+import { useExtensionState } from "../context/ExtensionStateContext"
 import { vscode } from "../utils/vscode"
+import VSCodeButtonLink from "./VSCodeButtonLink"
 
 interface ApiOptionsProps {
     showModelOptions: boolean
@@ -113,6 +115,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
                 OpenRouter
                 AWS Bedrock
                 GCP Vertex AI
+                Google Gemini
                 OpenAI Compatible
                 Ollama
@@ -161,7 +164,9 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
                 }}>
                 This key is stored locally and only used to make API requests from this extension.
                 {!apiConfiguration?.apiKey && (
-                    
+                    
                         You can get an Anthropic API key by signing up here.
                 )}
@@ -311,20 +316,48 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
                 To use Google Cloud Vertex AI, you need to
+                    style={{ display: "inline", fontSize: "inherit" }}>
                     {
                         "1) create a Google Cloud account › enable the Vertex AI API › enable the desired Claude models,"
                     }{" "}
+                    style={{ display: "inline", fontSize: "inherit" }}>
                     {"2) install the Google Cloud CLI › configure Application Default Credentials."}
             )}
+            {selectedProvider === "gemini" && (
+                
+                    
+                        Gemini API Key
+                    
+                    
+                        This key is stored locally and only used to make API requests from this extension.
+                        {!apiConfiguration?.geminiApiKey && (
+                            
+                                You can get a Gemini API key by signing up here.
+                            
+                        )}
+                    
+                
+            )}
+
             {selectedProvider === "openai" && (
                 { started, see their
+                    style={{ display: "inline", fontSize: "inherit" }}>
                     quickstart guide.
                 {" "}
                 You can use any model that supports{" "}
-                
+                
                     tool use.
@@ -454,9 +489,10 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
                 {selectedProvider === "openrouter" && createDropdown(openRouterModels)}
                 {selectedProvider === "bedrock" && createDropdown(bedrockModels)}
                 {selectedProvider === "vertex" && createDropdown(vertexModels)}
+                {selectedProvider === "gemini" && createDropdown(geminiModels)}
-                
+                
             )}
@@ -476,7 +512,8 @@ export const formatPrice = (price: number) => {
     }).format(price)
 }
 
-const ModelInfoView = ({ modelInfo }: { modelInfo: ModelInfo }) => {
+const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
+    const isGemini = Object.keys(geminiModels).includes(selectedModelId)
     return (
 
             { doesNotSupportLabel="Does not support images" />
-            
-            
+            {!isGemini && (
+                <>
+                    
+                    
+                
+            )}
             Max output: {modelInfo?.maxTokens?.toLocaleString()} tokens
-            Input price: {formatPrice(modelInfo.inputPrice)}/million tokens
+            {modelInfo.inputPrice > 0 && (
+                <>
+                    Input price: {formatPrice(modelInfo.inputPrice)}/million
+                    tokens
+                
+            )}
             {modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && (
                 <>
-                    
                     Cache writes price:{" "}
                     {formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens
                     Cache reads price:{" "}
                     {formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens
+                    
+                
+            )}
+            {modelInfo.outputPrice > 0 && (
+                <>
+                    Output price: {formatPrice(modelInfo.outputPrice)}/million
+                    tokens
+                
+            )}
+            {isGemini && (
+                <>
+                    
+                    * Free up to {selectedModelId === geminiDefaultModelId ? "15" : "2"} requests per minute. After
+                    that, billing depends on prompt size.{" "}
+                    
+                    For more info, see pricing details.
+                
+            )}
-            
-            Output price: {formatPrice(modelInfo.outputPrice)}/million tokens
 
     )
 }
@@ -563,6 +630,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
             return getProviderData(bedrockModels, bedrockDefaultModelId)
         case "vertex":
             return getProviderData(vertexModels, vertexDefaultModelId)
+        case "gemini":
+            return getProviderData(geminiModels, geminiDefaultModelId)
         case "openai":
             return {
                 selectedProvider: provider,
diff --git a/webview-ui/src/components/TaskHeader.tsx b/webview-ui/src/components/TaskHeader.tsx
index f05e86d..c746581 100644
--- a/webview-ui/src/components/TaskHeader.tsx
+++ b/webview-ui/src/components/TaskHeader.tsx
@@ -1,5 +1,5 @@
 import { VSCodeButton } from "@vscode/webview-ui-toolkit/react"
-import React, { memo, useEffect, useRef, useState } from "react"
+import React, { memo, useEffect, useMemo, useRef, useState } from "react"
 import { useWindowSize } from "react-use"
 import { ClaudeMessage } from "../../../src/shared/ExtensionMessage"
 import { useExtensionState } from "../context/ExtensionStateContext"
@@ -90,6 +90,14 @@ const TaskHeader: React.FC = ({
         }
     }, [task.text, windowWidth])
 
+    const isCostAvailable = useMemo(() => {
+        return (
+            apiConfiguration?.apiProvider !== "openai" &&
+            apiConfiguration?.apiProvider !== "ollama" &&
+            apiConfiguration?.apiProvider !== "gemini"
+        )
+    }, [apiConfiguration?.apiProvider])
+
     return (
 = ({
                 {!isTaskExpanded && {task.text}}
-                {!isTaskExpanded &&
-                    apiConfiguration?.apiProvider !== "openai" &&
-                    apiConfiguration?.apiProvider !== "ollama" && (
-                        ${totalCost?.toFixed(4)}
-                    )}
+                {!isTaskExpanded && isCostAvailable && (
+                        ${totalCost?.toFixed(4)}
+                )}
@@ -257,8 +262,7 @@ const TaskHeader: React.FC = ({
                 {tokensOut?.toLocaleString()}
-                {(apiConfiguration?.apiProvider === "openai" ||
-                    apiConfiguration?.apiProvider === "ollama") && }
+                {!isCostAvailable && }
                 {(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
@@ -280,21 +284,20 @@
                 )}
-                {apiConfiguration?.apiProvider !== "openai" &&
-                    apiConfiguration?.apiProvider !== "ollama" && (
-                        
-                            API Cost:
-                            ${totalCost?.toFixed(4)}
-                        
-                    )}
+                {isCostAvailable && (
+                    
+                        API Cost:
+                        ${totalCost?.toFixed(4)}
+                    
+                )}
             )}
         )}
diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx
index 97da44f..cb9c970 100644
--- a/webview-ui/src/context/ExtensionStateContext.tsx
+++ b/webview-ui/src/context/ExtensionStateContext.tsx
@@ -40,6 +40,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
                     config.vertexProjectId,
                     config.openAiApiKey,
                     config.ollamaModelId,
+                    config.geminiApiKey,
                 ].some((key) => key !== undefined)
                 : false
             setShowWelcome(!hasKey)
diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts
index c09a3e1..5e3dd1b 100644
--- a/webview-ui/src/utils/validate.ts
+++ b/webview-ui/src/utils/validate.ts
@@ -23,6 +23,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): string | undefined {
                 return "You must provide a valid Google Cloud Project ID and Region."
             }
             break
+        case "gemini":
+            if (!apiConfiguration.geminiApiKey) {
+                return "You must provide a valid API key or choose a different provider."
+            }
+            break
         case "openai":
             if (
                 !apiConfiguration.openAiBaseUrl ||
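Putting the pieces together, provider selection, validation, and handler dispatch compose roughly as follows. A sketch with illustrative import paths and configuration values; in the extension the configuration is assembled by ClaudeDevProvider from globalState plus SecretStorage:

```ts
import { buildApiHandler } from "./src/api"
import { ApiConfiguration } from "./src/shared/api"

const configuration: ApiConfiguration = {
    apiProvider: "gemini", // validateApiConfiguration() rejects this without a geminiApiKey
    apiModelId: "gemini-1.5-flash-latest",
    geminiApiKey: "...",
}

async function run() {
    // buildApiHandler() returns a GeminiHandler when apiProvider === "gemini".
    const handler = buildApiHandler(configuration)

    // Tools are omitted here for brevity; the real call sites always pass Claude Dev's tool set.
    const { message } = await handler.createMessage("You are Claude Dev.", [{ role: "user", content: "Hi!" }], [])

    // stop_reason is mapped from Gemini's finishReason ("STOP" -> "end_turn", and so on).
    console.log(message.content, message.stop_reason)
}

run()
```

Note that TaskHeader treats Gemini like the OpenAI-compatible and Ollama providers and hides the API cost row, consistent with the zero prices recorded for the Gemini models in src/shared/api.ts.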