Add support for OpenRouter and AWS Bedrock

This commit is contained in:
Saoud Rizwan
2024-08-03 14:24:56 -04:00
parent d441950b7f
commit c09a8462d7
19 changed files with 4458 additions and 194 deletions

View File

@@ -8,15 +8,17 @@ import osName from "os-name"
import pWaitFor from "p-wait-for"
import * as path from "path"
import { serializeError } from "serialize-error"
import treeKill from "tree-kill"
import * as vscode from "vscode"
import { ApiHandler, buildApiHandler } from "./api"
import { listFiles, parseSourceCodeForDefinitionsTopLevel } from "./parse-source-code"
import { ClaudeDevProvider } from "./providers/ClaudeDevProvider"
import { ApiConfiguration } from "./shared/api"
import { ClaudeRequestResult } from "./shared/ClaudeRequestResult"
import { DEFAULT_MAX_REQUESTS_PER_TASK } from "./shared/Constants"
import { ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "./shared/ExtensionMessage"
import { Tool, ToolName } from "./shared/Tool"
import { ClaudeAskResponse } from "./shared/WebviewMessage"
import treeKill from "tree-kill"
const SYSTEM_PROMPT =
() => `You are Claude Dev, a highly skilled software developer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
@@ -225,7 +227,7 @@ const tools: Tool[] = [
]
export class ClaudeDev {
private client: Anthropic
private api: ApiHandler
private maxRequestsPerTask: number
private requestCount = 0
apiConversationHistory: Anthropic.MessageParam[] = []
@@ -236,16 +238,21 @@ export class ClaudeDev {
private providerRef: WeakRef<ClaudeDevProvider>
abort: boolean = false
constructor(provider: ClaudeDevProvider, task: string, apiKey: string, maxRequestsPerTask?: number) {
constructor(
provider: ClaudeDevProvider,
task: string,
apiConfiguration: ApiConfiguration,
maxRequestsPerTask?: number
) {
this.providerRef = new WeakRef(provider)
this.client = new Anthropic({ apiKey })
this.api = buildApiHandler(apiConfiguration)
this.maxRequestsPerTask = maxRequestsPerTask ?? DEFAULT_MAX_REQUESTS_PER_TASK
this.startTask(task)
}
updateApiKey(apiKey: string) {
this.client = new Anthropic({ apiKey })
updateApi(apiConfiguration: ApiConfiguration) {
this.api = buildApiHandler(apiConfiguration)
}
updateMaxRequestsPerTask(maxRequestsPerTask: number | undefined) {
@@ -699,22 +706,7 @@ export class ClaudeDev {
async attemptApiRequest(): Promise<Anthropic.Messages.Message> {
try {
const response = await this.client.messages.create(
{
model: "claude-3-5-sonnet-20240620", // https://docs.anthropic.com/en/docs/about-claude/models
// beta max tokens
max_tokens: 8192,
system: SYSTEM_PROMPT(),
messages: this.apiConversationHistory,
tools: tools,
tool_choice: { type: "auto" },
},
{
// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
headers: { "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" },
}
)
return response
return await this.api.createMessage(SYSTEM_PROMPT(), this.apiConversationHistory, tools)
} catch (error) {
const { response } = await this.ask(
"api_req_failed",

34
src/api/anthropic.ts Normal file
View File

@@ -0,0 +1,34 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "."
import { ApiHandlerOptions } from "../shared/api"
export class AnthropicHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: Anthropic
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new Anthropic({ apiKey: this.options.apiKey })
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<Anthropic.Messages.Message> {
return await this.client.messages.create(
{
model: "claude-3-5-sonnet-20240620", // https://docs.anthropic.com/en/docs/about-claude/models
max_tokens: 8192, // beta max tokens
system: systemPrompt,
messages,
tools,
tool_choice: { type: "auto" },
},
{
// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
headers: { "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" },
}
)
}
}

39
src/api/bedrock.ts Normal file
View File

@@ -0,0 +1,39 @@
import AnthropicBedrock from "@anthropic-ai/bedrock-sdk"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandlerOptions } from "../shared/api"
import { ApiHandler } from "."
// https://docs.anthropic.com/en/api/claude-on-amazon-bedrock
export class AwsBedrockHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: AnthropicBedrock
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new AnthropicBedrock({
// Authenticate by either providing the keys below or use the default AWS credential providers, such as
// using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
awsAccessKey: this.options.awsAccessKey,
awsSecretKey: this.options.awsSecretKey,
// awsRegion changes the aws region to which the request is made. By default, we read AWS_REGION,
// and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
awsRegion: this.options.awsRegion,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<Anthropic.Messages.Message> {
return await this.client.messages.create({
model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
max_tokens: 4096,
system: systemPrompt,
messages,
tools,
tool_choice: { type: "auto" },
})
}
}

27
src/api/index.ts Normal file
View File

@@ -0,0 +1,27 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiConfiguration } from "../shared/api"
import { AnthropicHandler } from "./anthropic"
import { AwsBedrockHandler } from "./bedrock"
import { OpenRouterHandler } from "./openrouter"
export interface ApiHandler {
createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<Anthropic.Messages.Message>
}
export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
const { apiProvider, ...options } = configuration
switch (apiProvider) {
case "anthropic":
return new AnthropicHandler(options)
case "openrouter":
return new OpenRouterHandler(options)
case "bedrock":
return new AwsBedrockHandler(options)
default:
throw new Error(`Unknown API provider: ${apiProvider}`)
}
}

140
src/api/openrouter.ts Normal file
View File

@@ -0,0 +1,140 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandlerOptions } from "../shared/api"
import { ApiHandler } from "."
export class OpenRouterHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
baseURL: "https://openrouter.ai/api/v1",
apiKey: this.options.openRouterApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/saoudrizwan/claude-dev", // Optional, for including your app on openrouter.ai rankings.
"X-Title": "claude-dev", // Optional. Shows in rankings on openrouter.ai.
},
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<Anthropic.Messages.Message> {
// Convert Anthropic messages to OpenAI format
const openAIMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...messages.map((msg) => {
const baseMessage = {
content:
typeof msg.content === "string"
? msg.content
: msg.content
.map((part) => {
if ("text" in part) {
return part.text
} else if ("source" in part) {
return { type: "image_url" as const, image_url: { url: part.source.data } }
}
return ""
})
.filter(Boolean)
.join("\n"),
}
if (msg.role === "user") {
return { ...baseMessage, role: "user" as const }
} else if (msg.role === "assistant") {
const assistantMessage: OpenAI.Chat.ChatCompletionAssistantMessageParam = {
...baseMessage,
role: "assistant" as const,
}
if ("tool_calls" in msg && Array.isArray(msg.tool_calls) && msg.tool_calls.length > 0) {
assistantMessage.tool_calls = msg.tool_calls.map((toolCall) => ({
id: toolCall.id,
type: "function",
function: {
name: toolCall.function.name,
arguments: JSON.stringify(toolCall.function.arguments),
},
}))
}
return assistantMessage
}
throw new Error(`Unsupported message role: ${msg.role}`)
}),
]
// Convert Anthropic tools to OpenAI tools
const openAITools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const completion = await this.client.chat.completions.create({
model: "anthropic/claude-3.5-sonnet:beta",
max_tokens: 4096,
messages: openAIMessages,
tools: openAITools,
tool_choice: "auto",
})
// Convert OpenAI response to Anthropic format
const openAIMessage = completion.choices[0].message
const anthropicMessage: Anthropic.Messages.Message = {
id: completion.id,
type: "message",
role: "assistant",
content: [
{
type: "text",
text: openAIMessage.content || "",
},
],
model: completion.model,
stop_reason: this.mapFinishReason(completion.choices[0].finish_reason),
stop_sequence: null,
usage: {
input_tokens: completion.usage?.prompt_tokens || 0,
output_tokens: completion.usage?.completion_tokens || 0,
},
}
if (openAIMessage.tool_calls && openAIMessage.tool_calls.length > 0) {
anthropicMessage.content.push(
...openAIMessage.tool_calls.map((toolCall) => ({
type: "tool_use" as const,
id: toolCall.id,
name: toolCall.function.name,
input: JSON.parse(toolCall.function.arguments || "{}"),
}))
)
}
return anthropicMessage
}
private mapFinishReason(
finishReason: OpenAI.Chat.ChatCompletion.Choice["finish_reason"]
): Anthropic.Messages.Message["stop_reason"] {
switch (finishReason) {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
case "tool_calls":
return "tool_use"
case "content_filter":
return null // Anthropic doesn't have an exact equivalent
default:
return null
}
}
}

View File

@@ -1,12 +1,12 @@
import { Uri, Webview } from "vscode"
//import * as weather from "weather-js"
import { Anthropic } from "@anthropic-ai/sdk"
import os from "os"
import * as path from "path"
import * as vscode from "vscode"
import { ClaudeDev } from "../ClaudeDev"
import { ClaudeMessage, ExtensionMessage } from "../shared/ExtensionMessage"
import { ApiProvider } from "../shared/api"
import { ExtensionMessage } from "../shared/ExtensionMessage"
import { WebviewMessage } from "../shared/WebviewMessage"
import { Anthropic } from "@anthropic-ai/sdk"
import * as path from "path"
import os from "os"
/*
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -14,6 +14,9 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
*/
type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey"
type GlobalStateKey = "apiProvider" | "awsRegion" | "maxRequestsPerTask" | "lastShownAnnouncementId"
export class ClaudeDevProvider implements vscode.WebviewViewProvider {
public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
public static readonly tabPanelId = "claude-dev.TabPanelProvider"
@@ -131,15 +134,16 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.outputChannel.appendLine("Webview view resolved")
}
async tryToInitClaudeDevWithTask(task: string) {
async initClaudeDevWithTask(task: string) {
await this.clearTask() // ensures that an exising task doesn't exist before starting a new one, although this shouldn't be possible since user must clear task before starting a new one
const [apiKey, maxRequestsPerTask] = await Promise.all([
this.getSecret("apiKey") as Promise<string | undefined>,
this.getGlobalState("maxRequestsPerTask") as Promise<number | undefined>,
])
if (this.view && apiKey) {
this.claudeDev = new ClaudeDev(this, task, apiKey, maxRequestsPerTask)
}
const { apiProvider, apiKey, openRouterApiKey, awsAccessKey, awsSecretKey, awsRegion, maxRequestsPerTask } =
await this.getState()
this.claudeDev = new ClaudeDev(
this,
task,
{ apiProvider, apiKey, openRouterApiKey, awsAccessKey, awsSecretKey, awsRegion },
maxRequestsPerTask
)
}
// Send any JSON serializable data to the react app
@@ -249,11 +253,20 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
// Could also do this in extension .ts
//this.postMessageToWebview({ type: "text", text: `Extension: ${Date.now()}` })
// initializing new instance of ClaudeDev will make sure that any agentically running promises in old instance don't affect our new task. this essentially creates a fresh slate for the new task
await this.tryToInitClaudeDevWithTask(message.text!)
await this.initClaudeDevWithTask(message.text!)
break
case "apiKey":
await this.storeSecret("apiKey", message.text!)
this.claudeDev?.updateApiKey(message.text!)
case "apiConfiguration":
if (message.apiConfiguration) {
const { apiProvider, apiKey, openRouterApiKey, awsAccessKey, awsSecretKey, awsRegion } =
message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider)
await this.storeSecret("apiKey", apiKey)
await this.storeSecret("openRouterApiKey", openRouterApiKey)
await this.storeSecret("awsAccessKey", awsAccessKey)
await this.storeSecret("awsSecretKey", awsSecretKey)
await this.updateGlobalState("awsRegion", awsRegion)
this.claudeDev?.updateApi(message.apiConfiguration)
}
await this.postStateToWebview()
break
case "maxRequestsPerTask":
@@ -369,15 +382,20 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
}
async postStateToWebview() {
const [apiKey, maxRequestsPerTask, lastShownAnnouncementId] = await Promise.all([
this.getSecret("apiKey") as Promise<string | undefined>,
this.getGlobalState("maxRequestsPerTask") as Promise<number | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
])
const {
apiProvider,
apiKey,
openRouterApiKey,
awsAccessKey,
awsSecretKey,
awsRegion,
maxRequestsPerTask,
lastShownAnnouncementId,
} = await this.getState()
this.postMessageToWebview({
type: "state",
state: {
apiKey,
apiConfiguration: { apiProvider, apiKey, openRouterApiKey, awsAccessKey, awsSecretKey, awsRegion },
maxRequestsPerTask,
themeName: vscode.workspace.getConfiguration("workbench").get<string>("colorTheme"),
claudeMessages: this.claudeDev?.claudeMessages || [],
@@ -476,13 +494,45 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
https://www.eliostruyf.com/devhack-code-extension-storage-options/
*/
async getState() {
const [
apiProvider,
apiKey,
openRouterApiKey,
awsAccessKey,
awsSecretKey,
awsRegion,
maxRequestsPerTask,
lastShownAnnouncementId,
] = await Promise.all([
this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
this.getSecret("apiKey") as Promise<string | undefined>,
this.getSecret("openRouterApiKey") as Promise<string | undefined>,
this.getSecret("awsAccessKey") as Promise<string | undefined>,
this.getSecret("awsSecretKey") as Promise<string | undefined>,
this.getGlobalState("awsRegion") as Promise<string | undefined>,
this.getGlobalState("maxRequestsPerTask") as Promise<number | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
])
return {
apiProvider: apiProvider || "anthropic", // for legacy users that were using Anthropic by default
apiKey,
openRouterApiKey,
awsAccessKey,
awsSecretKey,
awsRegion,
maxRequestsPerTask,
lastShownAnnouncementId,
}
}
// global
private async updateGlobalState(key: string, value: any) {
private async updateGlobalState(key: GlobalStateKey, value: any) {
await this.context.globalState.update(key, value)
}
private async getGlobalState(key: string) {
private async getGlobalState(key: GlobalStateKey) {
return await this.context.globalState.get(key)
}
@@ -508,11 +558,11 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
// secrets
private async storeSecret(key: string, value: any) {
private async storeSecret(key: SecretKey, value: any) {
await this.context.secrets.store(key, value)
}
private async getSecret(key: string) {
private async getSecret(key: SecretKey) {
return await this.context.secrets.get(key)
}
}

View File

@@ -1,5 +1,7 @@
// type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonTapped' or 'settingsButtonTapped' or 'hello'
import { ApiConfiguration } from "./api"
// webview will hold state
export interface ExtensionMessage {
type: "action" | "state"
@@ -9,7 +11,7 @@ export interface ExtensionMessage {
}
export interface ExtensionState {
apiKey?: string
apiConfiguration?: ApiConfiguration
maxRequestsPerTask?: number
themeName?: string
claudeMessages: ClaudeMessage[]

View File

@@ -1,6 +1,8 @@
import { ApiConfiguration, ApiProvider } from "./api"
export interface WebviewMessage {
type:
| "apiKey"
| "apiConfiguration"
| "maxRequestsPerTask"
| "webviewDidLaunch"
| "newTask"
@@ -10,6 +12,7 @@ export interface WebviewMessage {
| "downloadTask"
text?: string
askResponse?: ClaudeAskResponse
apiConfiguration?: ApiConfiguration
}
export type ClaudeAskResponse = "yesButtonTapped" | "noButtonTapped" | "textResponse"

13
src/shared/api.ts Normal file
View File

@@ -0,0 +1,13 @@
export type ApiProvider = "anthropic" | "openrouter" | "bedrock"
export interface ApiHandlerOptions {
apiKey?: string // anthropic
openRouterApiKey?: string
awsAccessKey?: string
awsSecretKey?: string
awsRegion?: string
}
export type ApiConfiguration = ApiHandlerOptions & {
apiProvider?: ApiProvider
}

16
src/utils/getNonce.ts Normal file
View File

@@ -0,0 +1,16 @@
/**
* A helper function that returns a unique alphanumeric identifier called a nonce.
*
* @remarks This function is primarily used to help enforce content security
* policies for resources/scripts being executed in a webview context.
*
* @returns A nonce
*/
export function getNonce() {
let text = ""
const possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
for (let i = 0; i < 32; i++) {
text += possible.charAt(Math.floor(Math.random() * possible.length))
}
return text
}

15
src/utils/getUri.ts Normal file
View File

@@ -0,0 +1,15 @@
import { Uri, Webview } from "vscode"
/**
* A helper function which will get the webview URI of a given file or resource.
*
* @remarks This URI can be used within a webview's HTML as a link to the
* given file/resource.
*
* @param webview A reference to the extension webview
* @param extensionUri The URI of the directory containing the extension
* @param pathList An array of strings representing the path to a file/resource
* @returns A URI pointing to the file/resource
*/
export function getUri(webview: Webview, extensionUri: Uri, pathList: string[]) {
return webview.asWebviewUri(Uri.joinPath(extensionUri, ...pathList))
}