From fbb7620fa180ea92685896db3711938bd6783118 Mon Sep 17 00:00:00 2001
From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com>
Date: Thu, 12 Sep 2024 08:11:33 -0400
Subject: [PATCH] Add gemini support
---
package-lock.json | 14 +-
package.json | 1 +
src/api/gemini.ts | 57 ++++++++
src/api/index.ts | 3 +
src/providers/ClaudeDevProvider.ts | 15 +-
src/shared/api.ts | 26 +++-
src/utils/gemini-format.ts | 137 ++++++++++++++++++
webview-ui/src/components/ApiOptions.tsx | 111 +++++++++++---
webview-ui/src/components/TaskHeader.tsx | 75 +++++-----
.../src/context/ExtensionStateContext.tsx | 1 +
webview-ui/src/utils/validate.ts | 5 +
11 files changed, 384 insertions(+), 61 deletions(-)
create mode 100644 src/api/gemini.ts
create mode 100644 src/utils/gemini-format.ts
diff --git a/package-lock.json b/package-lock.json
index 369f7b6..c1a2cd8 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,17 +1,18 @@
{
"name": "claude-dev",
- "version": "1.5.34",
+ "version": "1.6.4",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "claude-dev",
- "version": "1.5.34",
+ "version": "1.6.4",
"license": "MIT",
"dependencies": {
"@anthropic-ai/bedrock-sdk": "^0.10.2",
"@anthropic-ai/sdk": "^0.26.0",
"@anthropic-ai/vertex-sdk": "^0.4.1",
+ "@google/generative-ai": "^0.18.0",
"@types/clone-deep": "^4.0.4",
"@types/pdf-parse": "^1.1.4",
"@vscode/codicons": "^0.0.36",
@@ -2635,6 +2636,15 @@
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
}
},
+ "node_modules/@google/generative-ai": {
+ "version": "0.18.0",
+ "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz",
+ "integrity": "sha512-AhaIWSpk2tuhYHrBhUqC0xrWWznmYEja1/TRDIb+5kruBU5kUzMlFsXCQNO9PzyTZ4clUJ3CX/Rvy+Xm9x+w3g==",
+ "license": "Apache-2.0",
+ "engines": {
+ "node": ">=18.0.0"
+ }
+ },
"node_modules/@humanwhocodes/config-array": {
"version": "0.11.14",
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.14.tgz",
diff --git a/package.json b/package.json
index 2572d54..7a5333d 100644
--- a/package.json
+++ b/package.json
@@ -151,6 +151,7 @@
"@anthropic-ai/bedrock-sdk": "^0.10.2",
"@anthropic-ai/sdk": "^0.26.0",
"@anthropic-ai/vertex-sdk": "^0.4.1",
+ "@google/generative-ai": "^0.18.0",
"@types/clone-deep": "^4.0.4",
"@types/pdf-parse": "^1.1.4",
"@vscode/codicons": "^0.0.36",
diff --git a/src/api/gemini.ts b/src/api/gemini.ts
new file mode 100644
index 0000000..97f627c
--- /dev/null
+++ b/src/api/gemini.ts
@@ -0,0 +1,57 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai"
+import { ApiHandler, ApiHandlerMessageResponse } from "."
+import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../shared/api"
+import {
+ convertAnthropicMessageToGemini,
+ convertAnthropicToolToGemini,
+ convertGeminiResponseToAnthropic,
+} from "../utils/gemini-format"
+
+export class GeminiHandler implements ApiHandler {
+ private options: ApiHandlerOptions
+ private client: GoogleGenerativeAI
+
+ constructor(options: ApiHandlerOptions) {
+ if (!options.geminiApiKey) {
+ throw new Error("API key is required for Google Gemini")
+ }
+ this.options = options
+ this.client = new GoogleGenerativeAI(options.geminiApiKey)
+ }
+
+ async createMessage(
+ systemPrompt: string,
+ messages: Anthropic.Messages.MessageParam[],
+ tools: Anthropic.Messages.Tool[]
+ ): Promise<ApiHandlerMessageResponse> {
+ const model = this.client.getGenerativeModel({
+ model: this.getModel().id,
+ systemInstruction: systemPrompt,
+ tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }],
+ toolConfig: {
+ functionCallingConfig: {
+ mode: FunctionCallingMode.AUTO,
+ },
+ },
+ })
+ const result = await model.generateContent({
+ contents: messages.map(convertAnthropicMessageToGemini),
+ generationConfig: {
+ maxOutputTokens: this.getModel().info.maxTokens,
+ },
+ })
+ const message = convertGeminiResponseToAnthropic(result.response)
+
+ return { message }
+ }
+
+ getModel(): { id: GeminiModelId; info: ModelInfo } {
+ const modelId = this.options.apiModelId
+ if (modelId && modelId in geminiModels) {
+ const id = modelId as GeminiModelId
+ return { id, info: geminiModels[id] }
+ }
+ return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
+ }
+}
diff --git a/src/api/index.ts b/src/api/index.ts
index 9f06ccf..e741bc0 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -6,6 +6,7 @@ import { OpenRouterHandler } from "./openrouter"
import { VertexHandler } from "./vertex"
import { OpenAiHandler } from "./openai"
import { OllamaHandler } from "./ollama"
+import { GeminiHandler } from "./gemini"
export interface ApiHandlerMessageResponse {
message: Anthropic.Messages.Message
@@ -37,6 +38,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new OpenAiHandler(options)
case "ollama":
return new OllamaHandler(options)
+ case "gemini":
+ return new GeminiHandler(options)
default:
return new AnthropicHandler(options)
}
diff --git a/src/providers/ClaudeDevProvider.ts b/src/providers/ClaudeDevProvider.ts
index 7414806..ede6a00 100644
--- a/src/providers/ClaudeDevProvider.ts
+++ b/src/providers/ClaudeDevProvider.ts
@@ -18,7 +18,14 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
*/
-type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" | "openAiApiKey"
+type SecretKey =
+ | "apiKey"
+ | "openRouterApiKey"
+ | "awsAccessKey"
+ | "awsSecretKey"
+ | "awsSessionToken"
+ | "openAiApiKey"
+ | "geminiApiKey"
type GlobalStateKey =
| "apiProvider"
| "apiModelId"
@@ -329,6 +336,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaModelId,
ollamaBaseUrl,
anthropicBaseUrl,
+ geminiApiKey,
} = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId)
@@ -346,6 +354,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("ollamaModelId", ollamaModelId)
await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
+ await this.storeSecret("geminiApiKey", geminiApiKey)
this.claudeDev?.updateApi(message.apiConfiguration)
}
await this.postStateToWebview()
@@ -667,6 +676,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaModelId,
ollamaBaseUrl,
anthropicBaseUrl,
+ geminiApiKey,
lastShownAnnouncementId,
customInstructions,
alwaysAllowReadOnly,
@@ -688,6 +698,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getGlobalState("ollamaModelId") as Promise,
this.getGlobalState("ollamaBaseUrl") as Promise,
this.getGlobalState("anthropicBaseUrl") as Promise,
+ this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise,
this.getGlobalState("customInstructions") as Promise,
this.getGlobalState("alwaysAllowReadOnly") as Promise,
@@ -726,6 +737,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
ollamaModelId,
ollamaBaseUrl,
anthropicBaseUrl,
+ geminiApiKey,
},
lastShownAnnouncementId,
customInstructions,
@@ -804,6 +816,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
"awsSecretKey",
"awsSessionToken",
"openAiApiKey",
+ "geminiApiKey",
]
for (const key of secretKeys) {
await this.storeSecret(key, undefined)
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 0d64a78..db40002 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -1,4 +1,4 @@
-export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama"
+export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama" | "gemini"
export interface ApiHandlerOptions {
apiModelId?: string
@@ -16,6 +16,7 @@ export interface ApiHandlerOptions {
openAiModelId?: string
ollamaModelId?: string
ollamaBaseUrl?: string
+ geminiApiKey?: string
}
export type ApiConfiguration = ApiHandlerOptions & {
@@ -305,3 +306,26 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
inputPrice: 0,
outputPrice: 0,
}
+
+// Gemini
+// https://ai.google.dev/gemini-api/docs/models/gemini
+export type GeminiModelId = keyof typeof geminiModels
+export const geminiDefaultModelId: GeminiModelId = "gemini-1.5-flash-latest"
+export const geminiModels = {
+ "gemini-1.5-flash-latest": {
+ maxTokens: 8192,
+ contextWindow: 1_048_576,
+ supportsImages: true,
+ supportsPromptCache: false,
+ inputPrice: 0,
+ outputPrice: 0,
+ },
+ "gemini-1.5-pro-latest": {
+ maxTokens: 8192,
+ contextWindow: 2_097_152,
+ supportsImages: true,
+ supportsPromptCache: false,
+ inputPrice: 0,
+ outputPrice: 0,
+ },
+} as const satisfies Record<string, ModelInfo>
diff --git a/src/utils/gemini-format.ts b/src/utils/gemini-format.ts
new file mode 100644
index 0000000..dd3207f
--- /dev/null
+++ b/src/utils/gemini-format.ts
@@ -0,0 +1,137 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Content, EnhancedGenerateContentResponse, FunctionDeclaration, Part, SchemaType } from "@google/generative-ai"
+
+export function convertAnthropicContentToGemini(
+ content:
+ | string
+ | Array<
+ | Anthropic.Messages.TextBlockParam
+ | Anthropic.Messages.ImageBlockParam
+ | Anthropic.Messages.ToolUseBlockParam
+ | Anthropic.Messages.ToolResultBlockParam
+ >
+): Part[] {
+ if (typeof content === "string") {
+ return [{ text: content }]
+ }
+ return content.map((block) => {
+ switch (block.type) {
+ case "text":
+ return { text: block.text }
+ case "image":
+ if (block.source.type !== "base64") {
+ throw new Error("Unsupported image source type")
+ }
+ return {
+ inlineData: {
+ data: block.source.data,
+ mimeType: block.source.media_type,
+ },
+ }
+ case "tool_use":
+ return {
+ functionCall: {
+ name: block.name,
+ args: block.input,
+ },
+ } as Part
+ case "tool_result":
+ return {
+ functionResponse: {
+ name: block.tool_use_id,
+ response: {
+ content: block.content,
+ },
+ },
+ }
+ default:
+ throw new Error(`Unsupported content block type: ${(block as any).type}`)
+ }
+ })
+}
+
+export function convertAnthropicMessageToGemini(message: Anthropic.Messages.MessageParam): Content {
+ return {
+ role: message.role === "assistant" ? "model" : message.role,
+ parts: convertAnthropicContentToGemini(message.content),
+ }
+}
+
+export function convertAnthropicToolToGemini(tool: Anthropic.Messages.Tool): FunctionDeclaration {
+ return {
+ name: tool.name,
+ description: tool.description || "",
+ parameters: {
+ type: SchemaType.OBJECT,
+ properties: Object.fromEntries(
+ Object.entries(tool.input_schema.properties || {}).map(([key, value]) => [
+ key,
+ {
+ type: (value as any).type.toUpperCase(),
+ description: (value as any).description || "",
+ },
+ ])
+ ),
+ required: (tool.input_schema.required as string[]) || [],
+ },
+ }
+}
+
+export function convertGeminiResponseToAnthropic(
+ response: EnhancedGenerateContentResponse
+): Anthropic.Messages.Message {
+ const content: Anthropic.Messages.ContentBlock[] = []
+
+ // Add the main text response
+ const text = response.text()
+ if (text) {
+ content.push({ type: "text", text })
+ }
+
+ // Add function calls as tool_use blocks
+ const functionCalls = response.functionCalls()
+ if (functionCalls) {
+ functionCalls.forEach((call, index) => {
+ content.push({
+ type: "tool_use",
+ id: `tool_${index}`,
+ name: call.name,
+ input: call.args,
+ })
+ })
+ }
+
+ // Determine stop reason
+ let stop_reason: Anthropic.Messages.Message["stop_reason"] = null
+ const finishReason = response.candidates?.[0]?.finishReason
+ if (finishReason) {
+ switch (finishReason) {
+ case "STOP":
+ stop_reason = "end_turn"
+ break
+ case "MAX_TOKENS":
+ stop_reason = "max_tokens"
+ break
+ case "SAFETY":
+ case "RECITATION":
+ case "OTHER":
+ stop_reason = "stop_sequence"
+ break
+ // Add more cases if needed
+ }
+ }
+
+ return {
+ id: `msg_${Date.now()}`, // Generate a unique ID
+ type: "message",
+ role: "assistant",
+ content,
+ model: "",
+ stop_reason,
+ stop_sequence: null, // Gemini doesn't provide this information
+ usage: {
+ input_tokens: response.usageMetadata?.promptTokenCount ?? 0,
+ output_tokens: response.usageMetadata?.candidatesTokenCount ?? 0,
+ },
+ }
+}
diff --git a/webview-ui/src/components/ApiOptions.tsx b/webview-ui/src/components/ApiOptions.tsx
index 7621fef..c1f0c8e 100644
--- a/webview-ui/src/components/ApiOptions.tsx
+++ b/webview-ui/src/components/ApiOptions.tsx
@@ -1,13 +1,14 @@
import {
+ VSCodeCheckbox,
VSCodeDropdown,
VSCodeLink,
VSCodeOption,
VSCodeRadio,
VSCodeRadioGroup,
VSCodeTextField,
- VSCodeCheckbox,
} from "@vscode/webview-ui-toolkit/react"
import { memo, useCallback, useEffect, useMemo, useState } from "react"
+import { useEvent, useInterval } from "react-use"
import {
ApiConfiguration,
ModelInfo,
@@ -15,17 +16,18 @@ import {
anthropicModels,
bedrockDefaultModelId,
bedrockModels,
+ geminiDefaultModelId,
+ geminiModels,
openAiModelInfoSaneDefaults,
openRouterDefaultModelId,
openRouterModels,
vertexDefaultModelId,
vertexModels,
} from "../../../src/shared/api"
-import { useExtensionState } from "../context/ExtensionStateContext"
-import VSCodeButtonLink from "./VSCodeButtonLink"
import { ExtensionMessage } from "../../../src/shared/ExtensionMessage"
-import { useEvent, useInterval } from "react-use"
+import { useExtensionState } from "../context/ExtensionStateContext"
import { vscode } from "../utils/vscode"
+import VSCodeButtonLink from "./VSCodeButtonLink"
interface ApiOptionsProps {
showModelOptions: boolean
@@ -113,6 +115,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
OpenRouter
AWS Bedrock
GCP Vertex AI
+ Google Gemini
OpenAI Compatible
Ollama
@@ -161,7 +164,9 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
}}>
This key is stored locally and only used to make API requests from this extension.
{!apiConfiguration?.apiKey && (
-
+
You can get an Anthropic API key by signing up here.
)}
@@ -311,20 +316,48 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
To use Google Cloud Vertex AI, you need to
+ style={{ display: "inline", fontSize: "inherit" }}>
{
"1) create a Google Cloud account › enable the Vertex AI API › enable the desired Claude models,"
}
{" "}
+ style={{ display: "inline", fontSize: "inherit" }}>
{"2) install the Google Cloud CLI › configure Application Default Credentials."}
)}
+ {selectedProvider === "gemini" && (
+
+
+ Gemini API Key
+
+
+ This key is stored locally and only used to make API requests from this extension.
+ {!apiConfiguration?.geminiApiKey && (
+
+ You can get a Gemini API key by signing up here.
+
+ )}
+
+
+ )}
+
{selectedProvider === "openai" && (
{
started, see their
+ style={{ display: "inline", fontSize: "inherit" }}>
quickstart guide.
{" "}
You can use any model that supports{" "}
-
+
tool use.
@@ -454,9 +489,10 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
{selectedProvider === "openrouter" && createDropdown(openRouterModels)}
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
{selectedProvider === "vertex" && createDropdown(vertexModels)}
+ {selectedProvider === "gemini" && createDropdown(geminiModels)}
-
+
>
)}
@@ -476,7 +512,8 @@ export const formatPrice = (price: number) => {
}).format(price)
}
-const ModelInfoView = ({ modelInfo }: { modelInfo: ModelInfo }) => {
+const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
+ const isGemini = Object.keys(geminiModels).includes(selectedModelId)
return (
{
doesNotSupportLabel="Does not support images"
/>
-
-
+ {!isGemini && (
+ <>
+
+
+ >
+ )}
Max output: {modelInfo?.maxTokens?.toLocaleString()} tokens
- Input price: {formatPrice(modelInfo.inputPrice)}/million tokens
+ {modelInfo.inputPrice > 0 && (
+ <>
+ Input price: {formatPrice(modelInfo.inputPrice)}/million
+ tokens
+
+ >
+ )}
{modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && (
<>
-
Cache writes price:{" "}
{formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens
Cache reads price:{" "}
{formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens
+
+ >
+ )}
+ {modelInfo.outputPrice > 0 && (
+ <>
+ Output price: {formatPrice(modelInfo.outputPrice)}/million
+ tokens
+ >
+ )}
+ {isGemini && (
+ <>
+
+ * Free up to {selectedModelId === geminiDefaultModelId ? "15" : "2"} requests per minute. After
+ that, billing depends on prompt size.{" "}
+
+ For more info, see pricing details.
+
+
>
)}
-
- Output price: {formatPrice(modelInfo.outputPrice)}/million tokens
)
}
@@ -563,6 +630,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
return getProviderData(bedrockModels, bedrockDefaultModelId)
case "vertex":
return getProviderData(vertexModels, vertexDefaultModelId)
+ case "gemini":
+ return getProviderData(geminiModels, geminiDefaultModelId)
case "openai":
return {
selectedProvider: provider,
diff --git a/webview-ui/src/components/TaskHeader.tsx b/webview-ui/src/components/TaskHeader.tsx
index f05e86d..c746581 100644
--- a/webview-ui/src/components/TaskHeader.tsx
+++ b/webview-ui/src/components/TaskHeader.tsx
@@ -1,5 +1,5 @@
import { VSCodeButton } from "@vscode/webview-ui-toolkit/react"
-import React, { memo, useEffect, useRef, useState } from "react"
+import React, { memo, useEffect, useMemo, useRef, useState } from "react"
import { useWindowSize } from "react-use"
import { ClaudeMessage } from "../../../src/shared/ExtensionMessage"
import { useExtensionState } from "../context/ExtensionStateContext"
@@ -90,6 +90,14 @@ const TaskHeader: React.FC = ({
}
}, [task.text, windowWidth])
+ const isCostAvailable = useMemo(() => {
+ return (
+ apiConfiguration?.apiProvider !== "openai" &&
+ apiConfiguration?.apiProvider !== "ollama" &&
+ apiConfiguration?.apiProvider !== "gemini"
+ )
+ }, [apiConfiguration?.apiProvider])
+
return (
= ({
{!isTaskExpanded && {task.text}}
- {!isTaskExpanded &&
- apiConfiguration?.apiProvider !== "openai" &&
- apiConfiguration?.apiProvider !== "ollama" && (
-
- ${totalCost?.toFixed(4)}
-
- )}
+ {!isTaskExpanded && isCostAvailable && (
+
+ ${totalCost?.toFixed(4)}
+
+ )}
@@ -257,8 +262,7 @@ const TaskHeader: React.FC = ({
{tokensOut?.toLocaleString()}
- {(apiConfiguration?.apiProvider === "openai" ||
- apiConfiguration?.apiProvider === "ollama") && }
+ {!isCostAvailable && }
{(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
@@ -280,21 +284,20 @@ const TaskHeader: React.FC = ({
)}
- {apiConfiguration?.apiProvider !== "openai" &&
- apiConfiguration?.apiProvider !== "ollama" && (
-
-
- API Cost:
- ${totalCost?.toFixed(4)}
-
-
+ {isCostAvailable && (
+
+
+ API Cost:
+ ${totalCost?.toFixed(4)}
- )}
+
+
+ )}
>
)}
diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx
index 97da44f..cb9c970 100644
--- a/webview-ui/src/context/ExtensionStateContext.tsx
+++ b/webview-ui/src/context/ExtensionStateContext.tsx
@@ -40,6 +40,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
config.vertexProjectId,
config.openAiApiKey,
config.ollamaModelId,
+ config.geminiApiKey,
].some((key) => key !== undefined)
: false
setShowWelcome(!hasKey)
diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts
index c09a3e1..5e3dd1b 100644
--- a/webview-ui/src/utils/validate.ts
+++ b/webview-ui/src/utils/validate.ts
@@ -23,6 +23,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
return "You must provide a valid Google Cloud Project ID and Region."
}
break
+ case "gemini":
+ if (!apiConfiguration.geminiApiKey) {
+ return "You must provide a valid API key or choose a different provider."
+ }
+ break
case "openai":
if (
!apiConfiguration.openAiBaseUrl ||