From 06ccaf6f6761bfe5b586156d56ec2d194da86797 Mon Sep 17 00:00:00 2001 From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com> Date: Sun, 29 Sep 2024 01:32:58 -0400 Subject: [PATCH] Implement streaming for all providers --- src/api/index.ts | 1 - src/api/providers/anthropic.ts | 6 +-- src/api/providers/bedrock.ts | 63 +++++++++++++++++++---- src/api/providers/gemini.ts | 42 ++++++++-------- src/api/providers/ollama.ts | 43 +++++++--------- src/api/providers/openai-native.ts | 80 ++++++++++-------------------- src/api/providers/openai.ts | 50 +++++++++---------- src/api/providers/openrouter.ts | 4 +- src/api/providers/vertex.ts | 65 +++++++++++++++++++----- src/api/transform/stream.ts | 4 +- src/core/ClaudeDev.ts | 4 +- 11 files changed, 201 insertions(+), 161 deletions(-) diff --git a/src/api/index.ts b/src/api/index.ts index 69e2daa..388b9ce 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -12,7 +12,6 @@ import { ApiStream } from "./transform/stream" export interface ApiHandler { createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream - getModel(): { id: string; info: ModelInfo } } diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index d01a44b..168d9d5 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -117,8 +117,8 @@ export class AnthropicHandler implements ApiHandler { type: "usage", inputTokens: usage.input_tokens || 0, outputTokens: usage.output_tokens || 0, - cacheWriteTokens: usage.cache_creation_input_tokens || 0, - cacheReadTokens: usage.cache_read_input_tokens || 0, + cacheWriteTokens: usage.cache_creation_input_tokens || undefined, + cacheReadTokens: usage.cache_read_input_tokens || undefined, } break case "message_delta": @@ -128,8 +128,6 @@ export class AnthropicHandler implements ApiHandler { type: "usage", inputTokens: 0, outputTokens: chunk.usage.output_tokens || 0, - cacheWriteTokens: 0, - cacheReadTokens: 0, } break case "message_stop": diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 7dd90e8..f098426 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -1,7 +1,8 @@ import AnthropicBedrock from "@anthropic-ai/bedrock-sdk" import { Anthropic } from "@anthropic-ai/sdk" -import { ApiHandler, ApiHandlerMessageResponse } from "../" +import { ApiHandler } from "../" import { ApiHandlerOptions, bedrockDefaultModelId, BedrockModelId, bedrockModels, ModelInfo } from "../../shared/api" +import { ApiStream } from "../transform/stream" // https://docs.anthropic.com/en/api/claude-on-amazon-bedrock export class AwsBedrockHandler implements ApiHandler { @@ -23,21 +24,61 @@ export class AwsBedrockHandler implements ApiHandler { }) } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { - const message = await this.client.messages.create({ + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const stream = await this.client.messages.create({ model: this.getModel().id, max_tokens: this.getModel().info.maxTokens, - temperature: 0.2, + temperature: 0, system: systemPrompt, messages, - tools, - tool_choice: { type: "auto" }, + stream: true, }) - return { message } + for await (const chunk of stream) { + switch (chunk.type) { + case "message_start": + const usage = chunk.message.usage + yield { + type: "usage", + inputTokens: usage.input_tokens || 0, + outputTokens: usage.output_tokens 
|| 0, + } + break + case "message_delta": + yield { + type: "usage", + inputTokens: 0, + outputTokens: chunk.usage.output_tokens || 0, + } + break + + case "content_block_start": + switch (chunk.content_block.type) { + case "text": + if (chunk.index > 0) { + yield { + type: "text", + text: "\n", + } + } + yield { + type: "text", + text: chunk.content_block.text, + } + break + } + break + case "content_block_delta": + switch (chunk.delta.type) { + case "text_delta": + yield { + type: "text", + text: chunk.delta.text, + } + break + } + break + } + } } getModel(): { id: BedrockModelId; info: ModelInfo } { diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index e1a7f63..3c667be 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -1,12 +1,9 @@ import { Anthropic } from "@anthropic-ai/sdk" -import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai" -import { ApiHandler, ApiHandlerMessageResponse } from "../" +import { GoogleGenerativeAI } from "@google/generative-ai" +import { ApiHandler } from "../" import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api" -import { - convertAnthropicMessageToGemini, - convertAnthropicToolToGemini, - convertGeminiResponseToAnthropic, -} from "../transform/gemini-format" +import { convertAnthropicMessageToGemini } from "../transform/gemini-format" +import { ApiStream } from "../transform/stream" export class GeminiHandler implements ApiHandler { private options: ApiHandlerOptions @@ -20,31 +17,32 @@ export class GeminiHandler implements ApiHandler { this.client = new GoogleGenerativeAI(options.geminiApiKey) } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { const model = this.client.getGenerativeModel({ model: this.getModel().id, systemInstruction: systemPrompt, - tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }], - toolConfig: { - functionCallingConfig: { - mode: FunctionCallingMode.AUTO, - }, - }, }) - const result = await model.generateContent({ + const result = await model.generateContentStream({ contents: messages.map(convertAnthropicMessageToGemini), generationConfig: { maxOutputTokens: this.getModel().info.maxTokens, - temperature: 0.2, + temperature: 0, }, }) - const message = convertGeminiResponseToAnthropic(result.response) - return { message } + for await (const chunk of result.stream) { + yield { + type: "text", + text: chunk.text(), + } + } + + const response = await result.response + yield { + type: "usage", + inputTokens: response.usageMetadata?.promptTokenCount ?? 0, + outputTokens: response.usageMetadata?.candidatesTokenCount ?? 
0, + } } getModel(): { id: GeminiModelId; info: ModelInfo } { diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts index ecce564..67e77c3 100644 --- a/src/api/providers/ollama.ts +++ b/src/api/providers/ollama.ts @@ -1,8 +1,9 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { ApiHandler, ApiHandlerMessageResponse } from "../" +import { ApiHandler } from "../" import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api" -import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiStream } from "../transform/stream" export class OllamaHandler implements ApiHandler { private options: ApiHandlerOptions @@ -16,37 +17,27 @@ export class OllamaHandler implements ApiHandler { }) } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages), ] - const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({ - type: "function", - function: { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }, - })) - const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + + const stream = await this.client.chat.completions.create({ model: this.options.ollamaModelId ?? "", messages: openAiMessages, - temperature: 0.2, - tools: openAiTools, - tool_choice: "auto", + temperature: 0, + stream: true, + }) + for await (const chunk of stream) { + const delta = chunk.choices[0]?.delta + if (delta?.content) { + yield { + type: "text", + text: delta.content, + } + } } - const completion = await this.client.chat.completions.create(createParams) - const errorMessage = (completion as any).error?.message - if (errorMessage) { - throw new Error(errorMessage) - } - const anthropicMessage = convertToAnthropicMessage(completion) - return { message: anthropicMessage } } getModel(): { id: string; info: ModelInfo } { diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index aa1d334..e0f87d5 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -1,6 +1,6 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { ApiHandler, ApiHandlerMessageResponse } from "../" +import { ApiHandler } from "../" import { ApiHandlerOptions, ModelInfo, @@ -8,8 +8,8 @@ import { OpenAiNativeModelId, openAiNativeModels, } from "../../shared/api" -import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format" -import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "../transform/o1-format" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiStream } from "../transform/stream" export class OpenAiNativeHandler implements ApiHandler { private options: ApiHandlerOptions @@ -22,65 +22,39 @@ export class OpenAiNativeHandler implements ApiHandler { }) } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { + async *createMessage(systemPrompt: string, messages: 
Anthropic.Messages.MessageParam[]): ApiStream { const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages), ] - const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({ - type: "function", - function: { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }, - })) - let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming + const stream = await this.client.chat.completions.create({ + model: this.getModel().id, + max_completion_tokens: this.getModel().info.maxTokens, + temperature: 0, + messages: openAiMessages, + stream: true, + stream_options: { include_usage: true }, + }) - switch (this.getModel().id) { - case "o1-preview": - case "o1-mini": - createParams = { - model: this.getModel().id, - max_completion_tokens: this.getModel().info.maxTokens, - messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt), + for await (const chunk of stream) { + const delta = chunk.choices[0]?.delta + if (delta?.content) { + yield { + type: "text", + text: delta.content, } - break - default: - createParams = { - model: this.getModel().id, - max_completion_tokens: this.getModel().info.maxTokens, - temperature: 0.2, - messages: openAiMessages, - tools: openAiTools, - tool_choice: "auto", + } + + // contains a null value except for the last chunk which contains the token usage statistics for the entire request + if (chunk.usage) { + yield { + type: "usage", + inputTokens: chunk.usage.prompt_tokens || 0, + outputTokens: chunk.usage.completion_tokens || 0, } - break + } } - - const completion = await this.client.chat.completions.create(createParams) - const errorMessage = (completion as any).error?.message - if (errorMessage) { - throw new Error(errorMessage) - } - - let anthropicMessage: Anthropic.Messages.Message - switch (this.getModel().id) { - case "o1-preview": - case "o1-mini": - anthropicMessage = convertO1ResponseToAnthropicMessage(completion) - break - default: - anthropicMessage = convertToAnthropicMessage(completion) - break - } - - return { message: anthropicMessage } } getModel(): { id: OpenAiNativeModelId; info: ModelInfo } { diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index ca6bbff..57cab17 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -1,13 +1,14 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI, { AzureOpenAI } from "openai" -import { ApiHandler, ApiHandlerMessageResponse } from "../index" import { ApiHandlerOptions, azureOpenAiDefaultApiVersion, ModelInfo, openAiModelInfoSaneDefaults, } from "../../shared/api" -import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format" +import { ApiHandler } from "../index" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiStream } from "../transform/stream" export class OpenAiHandler implements ApiHandler { private options: ApiHandlerOptions @@ -30,37 +31,34 @@ export class OpenAiHandler implements ApiHandler { } } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages), ] - const openAiTools: 
OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({ - type: "function", - function: { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }, - })) - const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + const stream = await this.client.chat.completions.create({ model: this.options.openAiModelId ?? "", messages: openAiMessages, - temperature: 0.2, - tools: openAiTools, - tool_choice: "auto", + temperature: 0, + stream: true, + stream_options: { include_usage: true }, + }) + for await (const chunk of stream) { + const delta = chunk.choices[0]?.delta + if (delta?.content) { + yield { + type: "text", + text: delta.content, + } + } + if (chunk.usage) { + yield { + type: "usage", + inputTokens: chunk.usage.prompt_tokens || 0, + outputTokens: chunk.usage.completion_tokens || 0, + } + } } - const completion = await this.client.chat.completions.create(createParams) - const errorMessage = (completion as any).error?.message - if (errorMessage) { - throw new Error(errorMessage) - } - const anthropicMessage = convertToAnthropicMessage(completion) - return { message: anthropicMessage } } getModel(): { id: string; info: ModelInfo } { diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index b862b67..5fa4311 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -124,8 +124,8 @@ export class OpenRouterHandler implements ApiHandler { type: "usage", inputTokens: generation?.native_tokens_prompt || 0, outputTokens: generation?.native_tokens_completion || 0, - cacheWriteTokens: 0, - cacheReadTokens: 0, + // cacheWriteTokens: 0, + // cacheReadTokens: 0, totalCost: generation?.total_cost || 0, } } catch (error) { diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index e368883..6e02b80 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -1,7 +1,8 @@ -import { AnthropicVertex } from "@anthropic-ai/vertex-sdk" import { Anthropic } from "@anthropic-ai/sdk" -import { ApiHandler, ApiHandlerMessageResponse } from "../" +import { AnthropicVertex } from "@anthropic-ai/vertex-sdk" +import { ApiHandler } from "../" import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api" +import { ApiStream } from "../transform/stream" // https://docs.anthropic.com/en/api/claude-on-vertex-ai export class VertexHandler implements ApiHandler { @@ -17,21 +18,61 @@ export class VertexHandler implements ApiHandler { }) } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { - const message = await this.client.messages.create({ + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const stream = await this.client.messages.create({ model: this.getModel().id, max_tokens: this.getModel().info.maxTokens, - temperature: 0.2, + temperature: 0, system: systemPrompt, messages, - tools, - tool_choice: { type: "auto" }, + stream: true, }) - return { message } + for await (const chunk of stream) { + switch (chunk.type) { + case "message_start": + const usage = chunk.message.usage + yield { + type: "usage", + inputTokens: usage.input_tokens || 0, + outputTokens: usage.output_tokens || 0, + } + break + case "message_delta": + yield { + type: "usage", + inputTokens: 0, + outputTokens: chunk.usage.output_tokens || 0, + } + break + + case "content_block_start": + switch 
(chunk.content_block.type) { + case "text": + if (chunk.index > 0) { + yield { + type: "text", + text: "\n", + } + } + yield { + type: "text", + text: chunk.content_block.text, + } + break + } + break + case "content_block_delta": + switch (chunk.delta.type) { + case "text_delta": + yield { + type: "text", + text: chunk.delta.text, + } + break + } + break + } + } } getModel(): { id: VertexModelId; info: ModelInfo } { diff --git a/src/api/transform/stream.ts b/src/api/transform/stream.ts index 4f8c562..0290201 100644 --- a/src/api/transform/stream.ts +++ b/src/api/transform/stream.ts @@ -10,7 +10,7 @@ export interface ApiStreamUsageChunk { type: "usage" inputTokens: number outputTokens: number - cacheWriteTokens: number - cacheReadTokens: number + cacheWriteTokens?: number + cacheReadTokens?: number totalCost?: number // openrouter } diff --git a/src/core/ClaudeDev.ts b/src/core/ClaudeDev.ts index ad12a59..6ad7f3f 100644 --- a/src/core/ClaudeDev.ts +++ b/src/core/ClaudeDev.ts @@ -2397,8 +2397,8 @@ ${this.customInstructions.trim()} case "usage": inputTokens += chunk.inputTokens outputTokens += chunk.outputTokens - cacheWriteTokens += chunk.cacheWriteTokens - cacheReadTokens += chunk.cacheReadTokens + cacheWriteTokens += chunk.cacheWriteTokens ?? 0 + cacheReadTokens += chunk.cacheReadTokens ?? 0 totalCost = chunk.totalCost break case "text":
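
For reference, below is a minimal sketch (not part of the patch) of how a caller might consume the new ApiStream async generator that createMessage now returns, mirroring the accumulation loop updated in ClaudeDev.ts above. The collectResponse name, the import path, and the way the handler is passed in are illustrative assumptions; the chunk shapes ("text" and "usage", with the cache token fields now optional) come from src/api/transform/stream.ts in this change.

import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "./src/api" // assumed import path relative to the repo root

async function collectResponse(
	handler: ApiHandler,
	systemPrompt: string,
	messages: Anthropic.Messages.MessageParam[]
) {
	let text = ""
	let inputTokens = 0
	let outputTokens = 0
	let cacheWriteTokens = 0
	let cacheReadTokens = 0

	// createMessage is now an async generator (ApiStream), so chunks can be
	// handled as they arrive instead of waiting for a single completed message.
	for await (const chunk of handler.createMessage(systemPrompt, messages)) {
		switch (chunk.type) {
			case "text":
				text += chunk.text
				break
			case "usage":
				inputTokens += chunk.inputTokens
				outputTokens += chunk.outputTokens
				// cacheWriteTokens/cacheReadTokens are optional after this patch,
				// so default them to 0 as ClaudeDev.ts does.
				cacheWriteTokens += chunk.cacheWriteTokens ?? 0
				cacheReadTokens += chunk.cacheReadTokens ?? 0
				break
		}
	}

	return { text, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens }
}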