Refactor API
src/api/providers/anthropic.ts (new file, 118 lines)
@@ -0,0 +1,118 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, ApiHandlerMessageResponse } from "../index"
import {
    anthropicDefaultModelId,
    AnthropicModelId,
    anthropicModels,
    ApiHandlerOptions,
    ModelInfo,
} from "../../shared/api"

export class AnthropicHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: Anthropic

    constructor(options: ApiHandlerOptions) {
        this.options = options
        this.client = new Anthropic({
            apiKey: this.options.apiKey,
            baseURL: this.options.anthropicBaseUrl || undefined,
        })
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const modelId = this.getModel().id
        switch (modelId) {
            case "claude-3-5-sonnet-20240620":
            case "claude-3-opus-20240229":
            case "claude-3-haiku-20240307": {
                /*
                The latest message will be the new user message, one before will be the assistant message from a previous request, and the user message before that will be a previously cached user message. So we need to mark the latest user message as ephemeral to cache it for the next request, and mark the second-to-last user message as ephemeral to let the server know the last message to retrieve from the cache for the current request.
                */
                const userMsgIndices = messages.reduce(
                    (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
                    [] as number[]
                )
                const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
                const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
                const message = await this.client.beta.promptCaching.messages.create(
                    {
                        model: modelId,
                        max_tokens: this.getModel().info.maxTokens,
                        temperature: 0.2,
                        system: [{ text: systemPrompt, type: "text", cache_control: { type: "ephemeral" } }], // setting cache breakpoint for system prompt so new tasks can reuse it
                        messages: messages.map((message, index) => {
                            if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
                                return {
                                    ...message,
                                    content:
                                        typeof message.content === "string"
                                            ? [
                                                  {
                                                      type: "text",
                                                      text: message.content,
                                                      cache_control: { type: "ephemeral" },
                                                  },
                                              ]
                                            : message.content.map((content, contentIndex) =>
                                                  contentIndex === message.content.length - 1
                                                      ? { ...content, cache_control: { type: "ephemeral" } }
                                                      : content
                                              ),
                                }
                            }
                            return message
                        }),
                        tools, // cache breakpoints go from tools > system > messages, and since tools don't change, we can just set the breakpoint at the end of system (this avoids having to set a breakpoint at the end of tools, which by itself does not meet the minimum requirements for haiku caching)
                        tool_choice: { type: "auto" },
                    },
                    (() => {
                        // prompt caching: https://x.com/alexalbert__/status/1823751995901272068
                        // https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
                        // https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393
                        switch (modelId) {
                            case "claude-3-5-sonnet-20240620":
                                return {
                                    headers: {
                                        "anthropic-beta": "prompt-caching-2024-07-31",
                                    },
                                }
                            case "claude-3-haiku-20240307":
                                return {
                                    headers: { "anthropic-beta": "prompt-caching-2024-07-31" },
                                }
                            default:
                                return undefined
                        }
                    })()
                )
                return { message }
            }
            default: {
                const message = await this.client.messages.create({
                    model: modelId,
                    max_tokens: this.getModel().info.maxTokens,
                    temperature: 0.2,
                    system: [{ text: systemPrompt, type: "text" }],
                    messages,
                    tools,
                    tool_choice: { type: "auto" },
                })
                return { message }
            }
        }
    }

    getModel(): { id: AnthropicModelId; info: ModelInfo } {
        const modelId = this.options.apiModelId
        if (modelId && modelId in anthropicModels) {
            const id = modelId as AnthropicModelId
            return { id, info: anthropicModels[id] }
        }
        return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
    }
}
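A minimal sketch (not part of the diff) of how the breakpoint selection above plays out on a hypothetical three-message history: only the latest and second-to-last user messages end up with cache_control: { type: "ephemeral" }.

import { Anthropic } from "@anthropic-ai/sdk"

// Hypothetical history: indices 0 and 2 are user messages, 1 is the assistant reply.
const history: Anthropic.Messages.MessageParam[] = [
    { role: "user", content: "read file A" },     // index 0 -> second-to-last user msg, marked ephemeral
    { role: "assistant", content: "done" },       // index 1 -> left untouched
    { role: "user", content: "now edit file B" }, // index 2 -> latest user msg, marked ephemeral
]

// Same reduction as in createMessage above.
const userMsgIndices = history.reduce(
    (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
    [] as number[]
) // [0, 2]
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 // 2
const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 // 0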
src/api/providers/bedrock.ts (new file, 51 lines)
@@ -0,0 +1,51 @@
import AnthropicBedrock from "@anthropic-ai/bedrock-sdk"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, bedrockDefaultModelId, BedrockModelId, bedrockModels, ModelInfo } from "../shared/api"

// https://docs.anthropic.com/en/api/claude-on-amazon-bedrock
export class AwsBedrockHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: AnthropicBedrock

    constructor(options: ApiHandlerOptions) {
        this.options = options
        this.client = new AnthropicBedrock({
            // Authenticate either by providing the keys below or by using the default AWS credential providers,
            // such as ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
            ...(this.options.awsAccessKey ? { awsAccessKey: this.options.awsAccessKey } : {}),
            ...(this.options.awsSecretKey ? { awsSecretKey: this.options.awsSecretKey } : {}),
            ...(this.options.awsSessionToken ? { awsSessionToken: this.options.awsSessionToken } : {}),

            // awsRegion changes the AWS region to which the request is made. By default, we read AWS_REGION,
            // and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
            awsRegion: this.options.awsRegion,
        })
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const message = await this.client.messages.create({
            model: this.getModel().id,
            max_tokens: this.getModel().info.maxTokens,
            temperature: 0.2,
            system: systemPrompt,
            messages,
            tools,
            tool_choice: { type: "auto" },
        })
        return { message }
    }

    getModel(): { id: BedrockModelId; info: ModelInfo } {
        const modelId = this.options.apiModelId
        if (modelId && modelId in bedrockModels) {
            const id = modelId as BedrockModelId
            return { id, info: bedrockModels[id] }
        }
        return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
    }
}
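A usage sketch (not from the commit) of the two authentication paths the constructor comment describes. The option names come from ApiHandlerOptions as used above; the key values and the model ID are placeholders.

// AwsBedrockHandler is the class defined in the file above.

// Explicit credentials: the spread expressions above forward these to AnthropicBedrock.
const withKeys = new AwsBedrockHandler({
    awsAccessKey: "AKIA...", // placeholder
    awsSecretKey: "...", // placeholder
    awsRegion: "us-east-1",
    apiModelId: "anthropic.claude-3-5-sonnet-20240620-v1:0", // hypothetical Bedrock model ID
})

// Default credential chain: omit the keys and let the SDK read ~/.aws/credentials
// or AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY from the environment.
const withDefaults = new AwsBedrockHandler({ awsRegion: "us-east-1" })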
src/api/providers/gemini.ts (new file, 58 lines)
@@ -0,0 +1,58 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../shared/api"
import {
    convertAnthropicMessageToGemini,
    convertAnthropicToolToGemini,
    convertGeminiResponseToAnthropic,
} from "./transform/gemini-format"

export class GeminiHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: GoogleGenerativeAI

    constructor(options: ApiHandlerOptions) {
        if (!options.geminiApiKey) {
            throw new Error("API key is required for Google Gemini")
        }
        this.options = options
        this.client = new GoogleGenerativeAI(options.geminiApiKey)
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const model = this.client.getGenerativeModel({
            model: this.getModel().id,
            systemInstruction: systemPrompt,
            tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }],
            toolConfig: {
                functionCallingConfig: {
                    mode: FunctionCallingMode.AUTO,
                },
            },
        })
        const result = await model.generateContent({
            contents: messages.map(convertAnthropicMessageToGemini),
            generationConfig: {
                maxOutputTokens: this.getModel().info.maxTokens,
                temperature: 0.2,
            },
        })
        const message = convertGeminiResponseToAnthropic(result.response)

        return { message }
    }

    getModel(): { id: GeminiModelId; info: ModelInfo } {
        const modelId = this.options.apiModelId
        if (modelId && modelId in geminiModels) {
            const id = modelId as GeminiModelId
            return { id, info: geminiModels[id] }
        }
        return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
    }
}
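The three convert* helpers come from ./transform/gemini-format, which this commit does not include. Purely as an assumption about the message-side mapping (roles and text parts only, tool blocks omitted), a rough sketch might look like this:

import { Anthropic } from "@anthropic-ai/sdk"
import { Content } from "@google/generative-ai"

// Illustration only: the real convertAnthropicMessageToGemini lives in ./transform/gemini-format.
// Maps Anthropic roles to Gemini's "user"/"model" and keeps text blocks.
function convertAnthropicMessageToGeminiSketch(message: Anthropic.Messages.MessageParam): Content {
    const text =
        typeof message.content === "string"
            ? message.content
            : message.content
                  .filter((block) => block.type === "text")
                  .map((block) => (block as Anthropic.Messages.TextBlockParam).text)
                  .join("\n")
    return {
        role: message.role === "assistant" ? "model" : "user",
        parts: [{ text }],
    }
}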
src/api/providers/ollama.ts (new file, 58 lines)
@@ -0,0 +1,58 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "./transform/openai-format"

export class OllamaHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: OpenAI

    constructor(options: ApiHandlerOptions) {
        this.options = options
        this.client = new OpenAI({
            baseURL: (this.options.ollamaBaseUrl || "http://localhost:11434") + "/v1",
            apiKey: "ollama",
        })
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
            { role: "system", content: systemPrompt },
            ...convertToOpenAiMessages(messages),
        ]
        const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
            type: "function",
            function: {
                name: tool.name,
                description: tool.description,
                parameters: tool.input_schema,
            },
        }))
        const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
            model: this.options.ollamaModelId ?? "",
            messages: openAiMessages,
            temperature: 0.2,
            tools: openAiTools,
            tool_choice: "auto",
        }
        const completion = await this.client.chat.completions.create(createParams)
        const errorMessage = (completion as any).error?.message
        if (errorMessage) {
            throw new Error(errorMessage)
        }
        const anthropicMessage = convertToAnthropicMessage(completion)
        return { message: anthropicMessage }
    }

    getModel(): { id: string; info: ModelInfo } {
        return {
            id: this.options.ollamaModelId ?? "",
            info: openAiModelInfoSaneDefaults,
        }
    }
}
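A usage sketch (assumption, not from the commit) pointing the handler at a local Ollama server. ollamaBaseUrl and ollamaModelId are the options the constructor and getModel read above; the apiKey is the fixed "ollama" placeholder that the OpenAI-compatible endpoint ignores.

// OllamaHandler is the class defined in the file above.
const handler = new OllamaHandler({
    ollamaBaseUrl: "http://localhost:11434", // the default used when omitted, per the constructor above
    ollamaModelId: "llama3.1", // hypothetical local model tag
})

// createMessage takes Anthropic-shaped inputs and returns an Anthropic-shaped message,
// so callers never see that the transport is the OpenAI-compatible /v1 endpoint.
const { message } = await handler.createMessage(
    "You are a helpful assistant.",
    [{ role: "user", content: "hello" }],
    []
)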
src/api/providers/openai-native.ts (new file, 94 lines)
@@ -0,0 +1,94 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import {
    ApiHandlerOptions,
    ModelInfo,
    openAiNativeDefaultModelId,
    OpenAiNativeModelId,
    openAiNativeModels,
} from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "./transform/openai-format"
import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "./transform/o1-format"

export class OpenAiNativeHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: OpenAI

    constructor(options: ApiHandlerOptions) {
        this.options = options
        this.client = new OpenAI({
            apiKey: this.options.openAiNativeApiKey,
        })
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
            { role: "system", content: systemPrompt },
            ...convertToOpenAiMessages(messages),
        ]
        const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
            type: "function",
            function: {
                name: tool.name,
                description: tool.description,
                parameters: tool.input_schema,
            },
        }))

        let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming

        switch (this.getModel().id) {
            case "o1-preview":
            case "o1-mini":
                createParams = {
                    model: this.getModel().id,
                    max_completion_tokens: this.getModel().info.maxTokens,
                    messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt),
                }
                break
            default:
                createParams = {
                    model: this.getModel().id,
                    max_completion_tokens: this.getModel().info.maxTokens,
                    temperature: 0.2,
                    messages: openAiMessages,
                    tools: openAiTools,
                    tool_choice: "auto",
                }
                break
        }

        const completion = await this.client.chat.completions.create(createParams)
        const errorMessage = (completion as any).error?.message
        if (errorMessage) {
            throw new Error(errorMessage)
        }

        let anthropicMessage: Anthropic.Messages.Message
        switch (this.getModel().id) {
            case "o1-preview":
            case "o1-mini":
                anthropicMessage = convertO1ResponseToAnthropicMessage(completion)
                break
            default:
                anthropicMessage = convertToAnthropicMessage(completion)
                break
        }

        return { message: anthropicMessage }
    }

    getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
        const modelId = this.options.apiModelId
        if (modelId && modelId in openAiNativeModels) {
            const id = modelId as OpenAiNativeModelId
            return { id, info: openAiNativeModels[id] }
        }
        return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
    }
}
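The o1 branch above drops temperature and tools and routes the system prompt through convertToO1Messages, which lives in ./transform/o1-format and is not shown in this commit. As an assumption about its behavior only (o1-preview and o1-mini did not accept a system role at the time), a plausible sketch:

import OpenAI from "openai"

// Hypothetical reimplementation for illustration; the real helper is in ./transform/o1-format.
function convertToO1MessagesSketch(
    messages: OpenAI.Chat.ChatCompletionMessageParam[],
    systemPrompt: string
): OpenAI.Chat.ChatCompletionMessageParam[] {
    // Fold the system prompt into a leading user message, since o1 models reject the system role.
    return [{ role: "user", content: systemPrompt }, ...messages]
}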
src/api/providers/openai.ts (new file, 70 lines)
@@ -0,0 +1,70 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI, { AzureOpenAI } from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "../index"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format"

export class OpenAiHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: OpenAI

    constructor(options: ApiHandlerOptions) {
        this.options = options
        // Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
        if (this.options.openAiBaseUrl?.toLowerCase().includes("azure.com")) {
            this.client = new AzureOpenAI({
                baseURL: this.options.openAiBaseUrl,
                apiKey: this.options.openAiApiKey,
                // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs
                // (make sure to update API options placeholder)
                apiVersion: this.options.azureApiVersion || "2024-08-01-preview",
            })
        } else {
            this.client = new OpenAI({
                baseURL: this.options.openAiBaseUrl,
                apiKey: this.options.openAiApiKey,
            })
        }
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
            { role: "system", content: systemPrompt },
            ...convertToOpenAiMessages(messages),
        ]
        const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
            type: "function",
            function: {
                name: tool.name,
                description: tool.description,
                parameters: tool.input_schema,
            },
        }))
        const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
            model: this.options.openAiModelId ?? "",
            messages: openAiMessages,
            temperature: 0.2,
            tools: openAiTools,
            tool_choice: "auto",
        }
        const completion = await this.client.chat.completions.create(createParams)
        const errorMessage = (completion as any).error?.message
        if (errorMessage) {
            throw new Error(errorMessage)
        }
        const anthropicMessage = convertToAnthropicMessage(completion)
        return { message: anthropicMessage }
    }

    getModel(): { id: string; info: ModelInfo } {
        return {
            id: this.options.openAiModelId ?? "",
            info: openAiModelInfoSaneDefaults,
        }
    }
}
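A sketch (assumption, not from the commit) of the two configurations the constructor above distinguishes: any base URL containing "azure.com" selects the AzureOpenAI client and picks up azureApiVersion; everything else goes through the plain OpenAI client. All URLs, keys, and model names below are placeholders.

// OpenAiHandler is the class defined in the file above.

// Azure deployment: the base URL points at a hypothetical Azure resource/deployment.
const azure = new OpenAiHandler({
    openAiBaseUrl: "https://my-resource.openai.azure.com/openai/deployments/gpt-4o", // hypothetical
    openAiApiKey: "...",
    azureApiVersion: "2024-08-01-preview", // same default the constructor falls back to
    openAiModelId: "gpt-4o", // hypothetical deployment/model name
})

// Any other OpenAI-compatible endpoint, or the official API when openAiBaseUrl is undefined.
const generic = new OpenAiHandler({
    openAiApiKey: "...",
    openAiModelId: "gpt-4o-mini", // hypothetical
})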
src/api/providers/openrouter.ts (new file, 290 lines)
@@ -0,0 +1,290 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import {
    ApiHandlerOptions,
    ModelInfo,
    openRouterDefaultModelId,
    OpenRouterModelId,
    openRouterModels,
} from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "./transform/openai-format"
import axios from "axios"
import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "./transform/o1-format"

export class OpenRouterHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: OpenAI

    constructor(options: ApiHandlerOptions) {
        this.options = options
        this.client = new OpenAI({
            baseURL: "https://openrouter.ai/api/v1",
            apiKey: this.options.openRouterApiKey,
            defaultHeaders: {
                "HTTP-Referer": "https://github.com/saoudrizwan/claude-dev", // Optional, for including your app on openrouter.ai rankings.
                "X-Title": "claude-dev", // Optional. Shows in rankings on openrouter.ai.
            },
        })
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        // Convert Anthropic messages to OpenAI format
        const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
            { role: "system", content: systemPrompt },
            ...convertToOpenAiMessages(messages),
        ]

        // prompt caching: https://openrouter.ai/docs/prompt-caching
        switch (this.getModel().id) {
            case "anthropic/claude-3.5-sonnet:beta":
            case "anthropic/claude-3-haiku:beta":
            case "anthropic/claude-3-opus:beta":
                openAiMessages[0] = {
                    role: "system",
                    content: [
                        {
                            type: "text",
                            text: systemPrompt,
                            // @ts-ignore-next-line
                            cache_control: { type: "ephemeral" },
                        },
                    ],
                }
                // Add cache_control to the last two user messages
                const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
                lastTwoUserMessages.forEach((msg) => {
                    if (typeof msg.content === "string") {
                        msg.content = [{ type: "text", text: msg.content }]
                    }
                    if (Array.isArray(msg.content)) {
                        let lastTextPart = msg.content.filter((part) => part.type === "text").pop()

                        if (!lastTextPart) {
                            lastTextPart = { type: "text", text: "..." }
                            msg.content.push(lastTextPart)
                        }
                        // @ts-ignore-next-line
                        lastTextPart["cache_control"] = { type: "ephemeral" }
                    }
                })
                break
            default:
                break
        }

        // Convert Anthropic tools to OpenAI tools
        const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
            type: "function",
            function: {
                name: tool.name,
                description: tool.description,
                parameters: tool.input_schema, // matches anthropic tool input schema (see https://platform.openai.com/docs/guides/function-calling)
            },
        }))

        let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming

        switch (this.getModel().id) {
            case "openai/o1-preview":
            case "openai/o1-mini":
                createParams = {
                    model: this.getModel().id,
                    max_tokens: this.getModel().info.maxTokens,
                    temperature: 0.2,
                    messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt),
                }
                break
            default:
                createParams = {
                    model: this.getModel().id,
                    max_tokens: this.getModel().info.maxTokens,
                    temperature: 0.2,
                    messages: openAiMessages,
                    tools: openAiTools,
                    tool_choice: "auto",
                }
                break
        }

        let completion: OpenAI.Chat.Completions.ChatCompletion
        try {
            completion = await this.client.chat.completions.create(createParams)
        } catch (error) {
            console.error("Error creating message from normal request. Using streaming fallback...", error)
            completion = await this.streamCompletion(createParams)
        }

        const errorMessage = (completion as any).error?.message // openrouter returns an error object instead of the openai sdk throwing an error
        if (errorMessage) {
            throw new Error(errorMessage)
        }

        let anthropicMessage: Anthropic.Messages.Message
        switch (this.getModel().id) {
            case "openai/o1-preview":
            case "openai/o1-mini":
                anthropicMessage = convertO1ResponseToAnthropicMessage(completion)
                break
            default:
                anthropicMessage = convertToAnthropicMessage(completion)
                break
        }

        // Check if the model is Gemini Flash and remove extra escapes in tool result args
        // switch (this.getModel().id) {
        //     case "google/gemini-pro-1.5":
        //     case "google/gemini-flash-1.5":
        //         const content = anthropicMessage.content
        //         for (const block of content) {
        //             if (
        //                 block.type === "tool_use" &&
        //                 typeof block.input === "object" &&
        //                 block.input !== null &&
        //                 "content" in block.input &&
        //                 typeof block.input.content === "string"
        //             ) {
        //                 block.input.content = unescapeGeminiContent(block.input.content)
        //             }
        //         }
        //         break
        //     default:
        //         break
        // }

        const genId = completion.id
        // Log the generation details from OpenRouter API
        try {
            const response = await axios.get(`https://openrouter.ai/api/v1/generation?id=${genId}`, {
                headers: {
                    Authorization: `Bearer ${this.options.openRouterApiKey}`,
                },
            })
            // @ts-ignore-next-line
            anthropicMessage.usage.total_cost = response.data?.data?.total_cost
            console.log("OpenRouter generation details:", response.data)
        } catch (error) {
            console.error("Error fetching OpenRouter generation details:", error)
        }

        return { message: anthropicMessage }
    }

    /*
    Streaming the completion is a fallback behavior for when a normal request responds with an invalid JSON object ("Unexpected end of JSON input"). This would usually happen in cases where the model makes tool calls with large arguments. After talking with OpenRouter folks, streaming mitigates this issue for now until they fix the underlying problem ("some weird data from anthropic got decoded wrongly and crashed the buffer")
    */
    async streamCompletion(
        createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
    ): Promise<OpenAI.Chat.Completions.ChatCompletion> {
        const stream = await this.client.chat.completions.create({
            ...createParams,
            stream: true,
        })

        let textContent: string = ""
        let toolCalls: OpenAI.Chat.ChatCompletionMessageToolCall[] = []

        try {
            let currentToolCall: (OpenAI.Chat.ChatCompletionMessageToolCall & { index?: number }) | null = null
            for await (const chunk of stream) {
                const delta = chunk.choices[0]?.delta
                if (delta?.content) {
                    textContent += delta.content
                }
                if (delta?.tool_calls) {
                    for (const toolCallDelta of delta.tool_calls) {
                        if (toolCallDelta.index === undefined) {
                            continue
                        }
                        if (!currentToolCall || currentToolCall.index !== toolCallDelta.index) {
                            // new index means new tool call, so add the previous one to the list
                            if (currentToolCall) {
                                toolCalls.push(currentToolCall)
                            }
                            currentToolCall = {
                                index: toolCallDelta.index,
                                id: toolCallDelta.id || "",
                                type: "function",
                                function: { name: "", arguments: "" },
                            }
                        }
                        if (toolCallDelta.id) {
                            currentToolCall.id = toolCallDelta.id
                        }
                        if (toolCallDelta.type) {
                            currentToolCall.type = toolCallDelta.type
                        }
                        if (toolCallDelta.function) {
                            if (toolCallDelta.function.name) {
                                currentToolCall.function.name = toolCallDelta.function.name
                            }
                            if (toolCallDelta.function.arguments) {
                                currentToolCall.function.arguments =
                                    (currentToolCall.function.arguments || "") + toolCallDelta.function.arguments
                            }
                        }
                    }
                }
            }
            if (currentToolCall) {
                toolCalls.push(currentToolCall)
            }
        } catch (error) {
            console.error("Error streaming completion:", error)
            throw error
        }

        // Usage information is not available in streaming responses, so we need to estimate token counts
        function approximateTokenCount(text: string): number {
            return Math.ceil(new TextEncoder().encode(text).length / 4)
        }
        const promptTokens = approximateTokenCount(
            createParams.messages
                .map((m) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
                .join(" ")
        )
        const completionTokens = approximateTokenCount(
            textContent + toolCalls.map((toolCall) => toolCall.function.arguments || "").join(" ")
        )

        const completion: OpenAI.Chat.Completions.ChatCompletion = {
            created: Date.now(),
            object: "chat.completion",
            id: `openrouter-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`, // this ID won't be traceable back to OpenRouter's systems if you need to debug issues
            choices: [
                {
                    message: {
                        role: "assistant",
                        content: textContent,
                        tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
                        refusal: null,
                    },
                    finish_reason: toolCalls.length > 0 ? "tool_calls" : "stop",
                    index: 0,
                    logprobs: null,
                },
            ],
            model: this.getModel().id,
            usage: {
                prompt_tokens: promptTokens,
                completion_tokens: completionTokens,
                total_tokens: promptTokens + completionTokens,
            },
        }

        return completion
    }

    getModel(): { id: OpenRouterModelId; info: ModelInfo } {
        const modelId = this.options.apiModelId
        if (modelId && modelId in openRouterModels) {
            const id = modelId as OpenRouterModelId
            return { id, info: openRouterModels[id] }
        }
        return { id: openRouterDefaultModelId, info: openRouterModels[openRouterDefaultModelId] }
    }
}
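A quick check of the byte-based token estimate used in streamCompletion above (one token per four UTF-8 bytes, rounded up). The values below are plain arithmetic on sample strings, not measurements against a real tokenizer.

function approximateTokenCount(text: string): number {
    return Math.ceil(new TextEncoder().encode(text).length / 4)
}

approximateTokenCount("hello world") // 11 bytes -> Math.ceil(11 / 4) = 3
approximateTokenCount("a".repeat(4000)) // 4000 bytes -> 1000
approximateTokenCount("héllo") // 6 bytes ("é" is 2 bytes in UTF-8) -> 2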
src/api/providers/vertex.ts (new file, 45 lines)
@@ -0,0 +1,45 @@
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, ApiHandlerMessageResponse } from "."
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../shared/api"

// https://docs.anthropic.com/en/api/claude-on-vertex-ai
export class VertexHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: AnthropicVertex

    constructor(options: ApiHandlerOptions) {
        this.options = options
        this.client = new AnthropicVertex({
            projectId: this.options.vertexProjectId,
            // https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
            region: this.options.vertexRegion,
        })
    }

    async createMessage(
        systemPrompt: string,
        messages: Anthropic.Messages.MessageParam[],
        tools: Anthropic.Messages.Tool[]
    ): Promise<ApiHandlerMessageResponse> {
        const message = await this.client.messages.create({
            model: this.getModel().id,
            max_tokens: this.getModel().info.maxTokens,
            temperature: 0.2,
            system: systemPrompt,
            messages,
            tools,
            tool_choice: { type: "auto" },
        })
        return { message }
    }

    getModel(): { id: VertexModelId; info: ModelInfo } {
        const modelId = this.options.apiModelId
        if (modelId && modelId in vertexModels) {
            const id = modelId as VertexModelId
            return { id, info: vertexModels[id] }
        }
        return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
    }
}
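A usage sketch (not part of the commit). vertexProjectId and vertexRegion are the options the constructor reads above; the Vertex SDK authenticates through Google Cloud Application Default Credentials, so no API key appears here. The project, region, and model ID are placeholders.

// VertexHandler is the class defined in the file above.
const handler = new VertexHandler({
    vertexProjectId: "my-gcp-project", // placeholder
    vertexRegion: "us-east5", // pick a region where Claude is available, per the linked docs
    apiModelId: "claude-3-5-sonnet@20240620", // hypothetical Vertex model ID
})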