Refactor API

This commit is contained in:
Saoud Rizwan
2024-09-24 10:43:31 -04:00
parent f774e62c13
commit a009c84597
12 changed files with 25 additions and 19 deletions

View File

@@ -0,0 +1,118 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, ApiHandlerMessageResponse } from "../index"
import {
anthropicDefaultModelId,
AnthropicModelId,
anthropicModels,
ApiHandlerOptions,
ModelInfo,
} from "../../shared/api"
export class AnthropicHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: Anthropic
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new Anthropic({
apiKey: this.options.apiKey,
baseURL: this.options.anthropicBaseUrl || undefined,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const modelId = this.getModel().id
switch (modelId) {
case "claude-3-5-sonnet-20240620":
case "claude-3-opus-20240229":
case "claude-3-haiku-20240307": {
/*
The latest message will be the new user message, one before will be the assistant message from a previous request, and the user message before that will be a previously cached user message. So we need to mark the latest user message as ephemeral to cache it for the next request, and mark the second to last user message as ephemeral to let the server know the last message to retrieve from the cache for the current request..
*/
const userMsgIndices = messages.reduce(
(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
[] as number[]
)
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
const message = await this.client.beta.promptCaching.messages.create(
{
model: modelId,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
system: [{ text: systemPrompt, type: "text", cache_control: { type: "ephemeral" } }], // setting cache breakpoint for system prompt so new tasks can reuse it
messages: messages.map((message, index) => {
if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
return {
...message,
content:
typeof message.content === "string"
? [
{
type: "text",
text: message.content,
cache_control: { type: "ephemeral" },
},
]
: message.content.map((content, contentIndex) =>
contentIndex === message.content.length - 1
? { ...content, cache_control: { type: "ephemeral" } }
: content
),
}
}
return message
}),
tools, // cache breakpoints go from tools > system > messages, and since tools dont change, we can just set the breakpoint at the end of system (this avoids having to set a breakpoint at the end of tools which by itself does not meet min requirements for haiku caching)
tool_choice: { type: "auto" },
},
(() => {
// prompt caching: https://x.com/alexalbert__/status/1823751995901272068
// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
// https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393
switch (modelId) {
case "claude-3-5-sonnet-20240620":
return {
headers: {
"anthropic-beta": "prompt-caching-2024-07-31",
},
}
case "claude-3-haiku-20240307":
return {
headers: { "anthropic-beta": "prompt-caching-2024-07-31" },
}
default:
return undefined
}
})()
)
return { message }
}
default: {
const message = await this.client.messages.create({
model: modelId,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
system: [{ text: systemPrompt, type: "text" }],
messages,
tools,
tool_choice: { type: "auto" },
})
return { message }
}
}
}
getModel(): { id: AnthropicModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in anthropicModels) {
const id = modelId as AnthropicModelId
return { id, info: anthropicModels[id] }
}
return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
}
}

View File

@@ -0,0 +1,51 @@
import AnthropicBedrock from "@anthropic-ai/bedrock-sdk"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, bedrockDefaultModelId, BedrockModelId, bedrockModels, ModelInfo } from "../shared/api"
// https://docs.anthropic.com/en/api/claude-on-amazon-bedrock
export class AwsBedrockHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: AnthropicBedrock
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new AnthropicBedrock({
// Authenticate by either providing the keys below or use the default AWS credential providers, such as
// using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
...(this.options.awsAccessKey ? { awsAccessKey: this.options.awsAccessKey } : {}),
...(this.options.awsSecretKey ? { awsSecretKey: this.options.awsSecretKey } : {}),
...(this.options.awsSessionToken ? { awsSessionToken: this.options.awsSessionToken } : {}),
// awsRegion changes the aws region to which the request is made. By default, we read AWS_REGION,
// and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
awsRegion: this.options.awsRegion,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const message = await this.client.messages.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
system: systemPrompt,
messages,
tools,
tool_choice: { type: "auto" },
})
return { message }
}
getModel(): { id: BedrockModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in bedrockModels) {
const id = modelId as BedrockModelId
return { id, info: bedrockModels[id] }
}
return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
}
}

View File

@@ -0,0 +1,58 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../shared/api"
import {
convertAnthropicMessageToGemini,
convertAnthropicToolToGemini,
convertGeminiResponseToAnthropic,
} from "./transform/gemini-format"
export class GeminiHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: GoogleGenerativeAI
constructor(options: ApiHandlerOptions) {
if (!options.geminiApiKey) {
throw new Error("API key is required for Google Gemini")
}
this.options = options
this.client = new GoogleGenerativeAI(options.geminiApiKey)
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const model = this.client.getGenerativeModel({
model: this.getModel().id,
systemInstruction: systemPrompt,
tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }],
toolConfig: {
functionCallingConfig: {
mode: FunctionCallingMode.AUTO,
},
},
})
const result = await model.generateContent({
contents: messages.map(convertAnthropicMessageToGemini),
generationConfig: {
maxOutputTokens: this.getModel().info.maxTokens,
temperature: 0.2,
},
})
const message = convertGeminiResponseToAnthropic(result.response)
return { message }
}
getModel(): { id: GeminiModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in geminiModels) {
const id = modelId as GeminiModelId
return { id, info: geminiModels[id] }
}
return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
}
}

View File

@@ -0,0 +1,58 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "./transform/openai-format"
export class OllamaHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
baseURL: (this.options.ollamaBaseUrl || "http://localhost:11434") + "/v1",
apiKey: "ollama",
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.options.ollamaModelId ?? "",
messages: openAiMessages,
temperature: 0.2,
tools: openAiTools,
tool_choice: "auto",
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
const anthropicMessage = convertToAnthropicMessage(completion)
return { message: anthropicMessage }
}
getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.ollamaModelId ?? "",
info: openAiModelInfoSaneDefaults,
}
}
}

View File

@@ -0,0 +1,94 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import {
ApiHandlerOptions,
ModelInfo,
openAiNativeDefaultModelId,
OpenAiNativeModelId,
openAiNativeModels,
} from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "./transform/openai-format"
import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "./transform/o1-format"
export class OpenAiNativeHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
apiKey: this.options.openAiNativeApiKey,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
switch (this.getModel().id) {
case "o1-preview":
case "o1-mini":
createParams = {
model: this.getModel().id,
max_completion_tokens: this.getModel().info.maxTokens,
messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt),
}
break
default:
createParams = {
model: this.getModel().id,
max_completion_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
break
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
let anthropicMessage: Anthropic.Messages.Message
switch (this.getModel().id) {
case "o1-preview":
case "o1-mini":
anthropicMessage = convertO1ResponseToAnthropicMessage(completion)
break
default:
anthropicMessage = convertToAnthropicMessage(completion)
break
}
return { message: anthropicMessage }
}
getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in openAiNativeModels) {
const id = modelId as OpenAiNativeModelId
return { id, info: openAiNativeModels[id] }
}
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
}
}

View File

@@ -0,0 +1,70 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI, { AzureOpenAI } from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "../index"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format"
export class OpenAiHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
if (this.options.openAiBaseUrl?.toLowerCase().includes("azure.com")) {
this.client = new AzureOpenAI({
baseURL: this.options.openAiBaseUrl,
apiKey: this.options.openAiApiKey,
// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs
// (make sure to update API options placeholder)
apiVersion: this.options.azureApiVersion || "2024-08-01-preview",
})
} else {
this.client = new OpenAI({
baseURL: this.options.openAiBaseUrl,
apiKey: this.options.openAiApiKey,
})
}
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.options.openAiModelId ?? "",
messages: openAiMessages,
temperature: 0.2,
tools: openAiTools,
tool_choice: "auto",
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
const anthropicMessage = convertToAnthropicMessage(completion)
return { message: anthropicMessage }
}
getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.openAiModelId ?? "",
info: openAiModelInfoSaneDefaults,
}
}
}

View File

@@ -0,0 +1,290 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import {
ApiHandlerOptions,
ModelInfo,
openRouterDefaultModelId,
OpenRouterModelId,
openRouterModels,
} from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "./transform/openai-format"
import axios from "axios"
import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "./transform/o1-format"
export class OpenRouterHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
baseURL: "https://openrouter.ai/api/v1",
apiKey: this.options.openRouterApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/saoudrizwan/claude-dev", // Optional, for including your app on openrouter.ai rankings.
"X-Title": "claude-dev", // Optional. Shows in rankings on openrouter.ai.
},
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
// Convert Anthropic messages to OpenAI format
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
// prompt caching: https://openrouter.ai/docs/prompt-caching
switch (this.getModel().id) {
case "anthropic/claude-3.5-sonnet:beta":
case "anthropic/claude-3-haiku:beta":
case "anthropic/claude-3-opus:beta":
openAiMessages[0] = {
role: "system",
content: [
{
type: "text",
text: systemPrompt,
// @ts-ignore-next-line
cache_control: { type: "ephemeral" },
},
],
}
// Add cache_control to the last two user messages
const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
lastTwoUserMessages.forEach((msg) => {
if (typeof msg.content === "string") {
msg.content = [{ type: "text", text: msg.content }]
}
if (Array.isArray(msg.content)) {
let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
if (!lastTextPart) {
lastTextPart = { type: "text", text: "..." }
msg.content.push(lastTextPart)
}
// @ts-ignore-next-line
lastTextPart["cache_control"] = { type: "ephemeral" }
}
})
break
default:
break
}
// Convert Anthropic tools to OpenAI tools
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema, // matches anthropic tool input schema (see https://platform.openai.com/docs/guides/function-calling)
},
}))
let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
switch (this.getModel().id) {
case "openai/o1-preview":
case "openai/o1-mini":
createParams = {
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt),
}
break
default:
createParams = {
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
break
}
let completion: OpenAI.Chat.Completions.ChatCompletion
try {
completion = await this.client.chat.completions.create(createParams)
} catch (error) {
console.error("Error creating message from normal request. Using streaming fallback...", error)
completion = await this.streamCompletion(createParams)
}
const errorMessage = (completion as any).error?.message // openrouter returns an error object instead of the openai sdk throwing an error
if (errorMessage) {
throw new Error(errorMessage)
}
let anthropicMessage: Anthropic.Messages.Message
switch (this.getModel().id) {
case "openai/o1-preview":
case "openai/o1-mini":
anthropicMessage = convertO1ResponseToAnthropicMessage(completion)
break
default:
anthropicMessage = convertToAnthropicMessage(completion)
break
}
// Check if the model is Gemini Flash and remove extra escapes in tool result args
// switch (this.getModel().id) {
// case "google/gemini-pro-1.5":
// case "google/gemini-flash-1.5":
// const content = anthropicMessage.content
// for (const block of content) {
// if (
// block.type === "tool_use" &&
// typeof block.input === "object" &&
// block.input !== null &&
// "content" in block.input &&
// typeof block.input.content === "string"
// ) {
// block.input.content = unescapeGeminiContent(block.input.content)
// }
// }
// break
// default:
// break
// }
const genId = completion.id
// Log the generation details from OpenRouter API
try {
const response = await axios.get(`https://openrouter.ai/api/v1/generation?id=${genId}`, {
headers: {
Authorization: `Bearer ${this.options.openRouterApiKey}`,
},
})
// @ts-ignore-next-line
anthropicMessage.usage.total_cost = response.data?.data?.total_cost
console.log("OpenRouter generation details:", response.data)
} catch (error) {
console.error("Error fetching OpenRouter generation details:", error)
}
return { message: anthropicMessage }
}
/*
Streaming the completion is a fallback behavior for when a normal request responds with an invalid JSON object ("Unexpected end of JSON input"). This would usually happen in cases where the model makes tool calls with large arguments. After talking with OpenRouter folks, streaming mitigates this issue for now until they fix the underlying problem ("some weird data from anthropic got decoded wrongly and crashed the buffer")
*/
async streamCompletion(
createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
): Promise<OpenAI.Chat.Completions.ChatCompletion> {
const stream = await this.client.chat.completions.create({
...createParams,
stream: true,
})
let textContent: string = ""
let toolCalls: OpenAI.Chat.ChatCompletionMessageToolCall[] = []
try {
let currentToolCall: (OpenAI.Chat.ChatCompletionMessageToolCall & { index?: number }) | null = null
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
if (delta?.content) {
textContent += delta.content
}
if (delta?.tool_calls) {
for (const toolCallDelta of delta.tool_calls) {
if (toolCallDelta.index === undefined) {
continue
}
if (!currentToolCall || currentToolCall.index !== toolCallDelta.index) {
// new index means new tool call, so add the previous one to the list
if (currentToolCall) {
toolCalls.push(currentToolCall)
}
currentToolCall = {
index: toolCallDelta.index,
id: toolCallDelta.id || "",
type: "function",
function: { name: "", arguments: "" },
}
}
if (toolCallDelta.id) {
currentToolCall.id = toolCallDelta.id
}
if (toolCallDelta.type) {
currentToolCall.type = toolCallDelta.type
}
if (toolCallDelta.function) {
if (toolCallDelta.function.name) {
currentToolCall.function.name = toolCallDelta.function.name
}
if (toolCallDelta.function.arguments) {
currentToolCall.function.arguments =
(currentToolCall.function.arguments || "") + toolCallDelta.function.arguments
}
}
}
}
}
if (currentToolCall) {
toolCalls.push(currentToolCall)
}
} catch (error) {
console.error("Error streaming completion:", error)
throw error
}
// Usage information is not available in streaming responses, so we need to estimate token counts
function approximateTokenCount(text: string): number {
return Math.ceil(new TextEncoder().encode(text).length / 4)
}
const promptTokens = approximateTokenCount(
createParams.messages
.map((m) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
.join(" ")
)
const completionTokens = approximateTokenCount(
textContent + toolCalls.map((toolCall) => toolCall.function.arguments || "").join(" ")
)
const completion: OpenAI.Chat.Completions.ChatCompletion = {
created: Date.now(),
object: "chat.completion",
id: `openrouter-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`, // this ID won't be traceable back to OpenRouter's systems if you need to debug issues
choices: [
{
message: {
role: "assistant",
content: textContent,
tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
refusal: null,
},
finish_reason: toolCalls.length > 0 ? "tool_calls" : "stop",
index: 0,
logprobs: null,
},
],
model: this.getModel().id,
usage: {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: promptTokens + completionTokens,
},
}
return completion
}
getModel(): { id: OpenRouterModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in openRouterModels) {
const id = modelId as OpenRouterModelId
return { id, info: openRouterModels[id] }
}
return { id: openRouterDefaultModelId, info: openRouterModels[openRouterDefaultModelId] }
}
}

View File

@@ -0,0 +1,45 @@
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../shared/api"
// https://docs.anthropic.com/en/api/claude-on-vertex-ai
export class VertexHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: AnthropicVertex
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new AnthropicVertex({
projectId: this.options.vertexProjectId,
// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
region: this.options.vertexRegion,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const message = await this.client.messages.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
system: systemPrompt,
messages,
tools,
tool_choice: { type: "auto" },
})
return { message }
}
getModel(): { id: VertexModelId; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId && modelId in vertexModels) {
const id = modelId as VertexModelId
return { id, info: vertexModels[id] }
}
return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
}
}