Get openrouter streaming working

Saoud Rizwan
2024-09-29 01:00:01 -04:00
parent 7271152f62
commit 59c188019a
4 changed files with 66 additions and 206 deletions


@@ -11,11 +11,7 @@ import { OpenAiNativeHandler } from "./providers/openai-native"
import { ApiStream } from "./transform/stream"
export interface ApiHandler {
createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): ApiStream
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
getModel(): { id: string; info: ModelInfo }
}
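The interface change drops the tools parameter and makes createMessage return an ApiStream async generator instead of a one-shot response. A minimal sketch of how a caller might consume the new signature, assuming the "text" and "usage" chunk shapes yielded by the OpenRouter handler further down (the helper name and import path are illustrative, not part of this commit):

import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "./index" // assumed path for this sketch

// Illustrative consumer of the streaming createMessage signature
async function collectAssistantText(
	handler: ApiHandler,
	systemPrompt: string,
	messages: Anthropic.Messages.MessageParam[]
): Promise<string> {
	let text = ""
	for await (const chunk of handler.createMessage(systemPrompt, messages)) {
		if (chunk.type === "text") {
			text += chunk.text // streamed model output
		} else if (chunk.type === "usage") {
			console.log(`tokens in/out: ${chunk.inputTokens}/${chunk.outputTokens}, cost: ${chunk.totalCost}`)
		}
	}
	return text
}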


@@ -1,6 +1,7 @@
import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse } from "../"
import { ApiHandler } from "../"
import {
ApiHandlerOptions,
ModelInfo,
@@ -8,9 +9,8 @@ import {
OpenRouterModelId,
openRouterModels,
} from "../../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format"
import axios from "axios"
import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "../transform/o1-format"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OpenRouterHandler implements ApiHandler {
private options: ApiHandlerOptions
@@ -28,11 +28,7 @@ export class OpenRouterHandler implements ApiHandler {
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
// Convert Anthropic messages to OpenAI format
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
@@ -77,206 +73,65 @@ export class OpenRouterHandler implements ApiHandler {
break
}
// Convert Anthropic tools to OpenAI tools
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema, // matches anthropic tool input schema (see https://platform.openai.com/docs/guides/function-calling)
},
}))
let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
switch (this.getModel().id) {
case "openai/o1-preview":
case "openai/o1-mini":
createParams = {
const stream = await this.client.chat.completions.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt),
}
break
default:
createParams = {
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
temperature: 0.2,
temperature: 0,
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
break
stream: true,
})
let genId: string | undefined
console.log("Starting stream processing for OpenRouter")
for await (const chunk of stream) {
console.log("Received chunk:", chunk)
// openrouter returns an error object instead of the openai sdk throwing an error
if ("error" in chunk) {
const error = chunk.error as { message?: string; code?: number }
console.error(`OpenRouter API Error: ${error?.code} - ${error?.message}`)
throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
}
let completion: OpenAI.Chat.Completions.ChatCompletion
try {
completion = await this.client.chat.completions.create(createParams)
} catch (error) {
console.error("Error creating message from normal request. Using streaming fallback...", error)
completion = await this.streamCompletion(createParams)
}
const errorMessage = (completion as any).error?.message // openrouter returns an error object instead of the openai sdk throwing an error
if (errorMessage) {
throw new Error(errorMessage)
}
let anthropicMessage: Anthropic.Messages.Message
switch (this.getModel().id) {
case "openai/o1-preview":
case "openai/o1-mini":
anthropicMessage = convertO1ResponseToAnthropicMessage(completion)
break
default:
anthropicMessage = convertToAnthropicMessage(completion)
break
}
// Check if the model is Gemini Flash and remove extra escapes in tool result args
// switch (this.getModel().id) {
// case "google/gemini-pro-1.5":
// case "google/gemini-flash-1.5":
// const content = anthropicMessage.content
// for (const block of content) {
// if (
// block.type === "tool_use" &&
// typeof block.input === "object" &&
// block.input !== null &&
// "content" in block.input &&
// typeof block.input.content === "string"
// ) {
// block.input.content = unescapeGeminiContent(block.input.content)
// }
// }
// break
// default:
// break
// }
const genId = completion.id
// Log the generation details from OpenRouter API
if (!genId && chunk.id) {
genId = chunk.id
console.log("Generation ID set:", genId)
}
const delta = chunk.choices[0]?.delta
if (delta?.content) {
console.log("Yielding content:", delta.content)
yield {
type: "text",
text: delta.content,
}
}
}
console.log("Stream processing completed")
try {
console.log("Fetching generation details for ID:", genId)
const response = await axios.get(`https://openrouter.ai/api/v1/generation?id=${genId}`, {
headers: {
Authorization: `Bearer ${this.options.openRouterApiKey}`,
},
})
// @ts-ignore-next-line
anthropicMessage.usage.total_cost = response.data?.data?.total_cost
const generation = response.data?.data
console.log("OpenRouter generation details:", response.data)
console.log("Yielding usage information")
yield {
type: "usage",
inputTokens: generation?.native_tokens_prompt || 0,
outputTokens: generation?.native_tokens_completion || 0,
cacheWriteTokens: 0,
cacheReadTokens: 0,
totalCost: generation?.total_cost || 0,
}
} catch (error) {
// ignore if fails
console.error("Error fetching OpenRouter generation details:", error)
}
return { message: anthropicMessage }
}
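The chunks yielded above imply that ../transform/stream defines types along these lines. This is a sketch inferred from the yield shapes in this diff, not the actual contents of that file:

// Sketch of the stream types implied by the yields above (field optionality is an assumption)
export type ApiStream = AsyncGenerator<ApiStreamChunk>

export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk

export interface ApiStreamTextChunk {
	type: "text"
	text: string
}

export interface ApiStreamUsageChunk {
	type: "usage"
	inputTokens: number
	outputTokens: number
	cacheWriteTokens?: number
	cacheReadTokens?: number
	totalCost?: number
}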
/*
Streaming the completion is a fallback behavior for when a normal request responds with an invalid JSON object ("Unexpected end of JSON input"). This would usually happen in cases where the model makes tool calls with large arguments. After talking with OpenRouter folks, streaming mitigates this issue for now until they fix the underlying problem ("some weird data from anthropic got decoded wrongly and crashed the buffer")
*/
async streamCompletion(
createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
): Promise<OpenAI.Chat.Completions.ChatCompletion> {
const stream = await this.client.chat.completions.create({
...createParams,
stream: true,
})
let textContent: string = ""
let toolCalls: OpenAI.Chat.ChatCompletionMessageToolCall[] = []
try {
let currentToolCall: (OpenAI.Chat.ChatCompletionMessageToolCall & { index?: number }) | null = null
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
if (delta?.content) {
textContent += delta.content
}
if (delta?.tool_calls) {
for (const toolCallDelta of delta.tool_calls) {
if (toolCallDelta.index === undefined) {
continue
}
if (!currentToolCall || currentToolCall.index !== toolCallDelta.index) {
// new index means new tool call, so add the previous one to the list
if (currentToolCall) {
toolCalls.push(currentToolCall)
}
currentToolCall = {
index: toolCallDelta.index,
id: toolCallDelta.id || "",
type: "function",
function: { name: "", arguments: "" },
}
}
if (toolCallDelta.id) {
currentToolCall.id = toolCallDelta.id
}
if (toolCallDelta.type) {
currentToolCall.type = toolCallDelta.type
}
if (toolCallDelta.function) {
if (toolCallDelta.function.name) {
currentToolCall.function.name = toolCallDelta.function.name
}
if (toolCallDelta.function.arguments) {
currentToolCall.function.arguments =
(currentToolCall.function.arguments || "") + toolCallDelta.function.arguments
}
}
}
}
}
if (currentToolCall) {
toolCalls.push(currentToolCall)
}
} catch (error) {
console.error("Error streaming completion:", error)
throw error
}
// Usage information is not available in streaming responses, so we need to estimate token counts
function approximateTokenCount(text: string): number {
return Math.ceil(new TextEncoder().encode(text).length / 4)
}
const promptTokens = approximateTokenCount(
createParams.messages
.map((m) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
.join(" ")
)
const completionTokens = approximateTokenCount(
textContent + toolCalls.map((toolCall) => toolCall.function.arguments || "").join(" ")
)
const completion: OpenAI.Chat.Completions.ChatCompletion = {
created: Date.now(),
object: "chat.completion",
id: `openrouter-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`, // this ID won't be traceable back to OpenRouter's systems if you need to debug issues
choices: [
{
message: {
role: "assistant",
content: textContent,
tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
refusal: null,
},
finish_reason: toolCalls.length > 0 ? "tool_calls" : "stop",
index: 0,
logprobs: null,
},
],
model: this.getModel().id,
usage: {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: promptTokens + completionTokens,
},
}
return completion
}
getModel(): { id: OpenRouterModelId; info: ModelInfo } {


@@ -577,7 +577,7 @@ export class ClaudeDev {
nextUserContent = [
{
type: "text",
text: "If you have completed the user's task, use the attempt_completion tool. If you require additional information from the user, use the ask_followup_question tool. Otherwise, if you have not completed the task and do not need additional information, then proceed with the next step of the task. (This is an automated message, so do not respond to it conversationally.)",
text: this.formatNoToolsResponse(),
},
]
this.consecutiveMistakeCount++
@@ -1242,11 +1242,7 @@ ${this.customInstructions.trim()}
}
}
}
const stream = this.api.createMessage(
systemPrompt,
this.apiConversationHistory,
TOOLS(cwd, this.api.getModel().info.supportsImages)
)
const stream = this.api.createMessage(systemPrompt, this.apiConversationHistory)
return stream
} catch (error) {
const { response } = await this.ask(
@@ -2467,6 +2463,15 @@ ${this.customInstructions.trim()}
await pWaitFor(() => this.userMessageContentReady)
// if the model did not tool use, then we need to tell it to either use a tool or attempt_completion
const didToolUse = this.assistantMessageContent.some((block) => block.type === "tool_call")
if (!didToolUse) {
this.userMessageContent.push({
type: "text",
text: this.formatNoToolsResponse(),
})
}
const recDidEndLoop = await this.recursivelyMakeClaudeRequests(this.userMessageContent)
didEndLoop = recDidEndLoop
} else {
@@ -2703,6 +2708,10 @@ ${this.customInstructions.trim()}
return `The tool execution failed with the following error:\n<error>\n${error}\n</error>`
}
formatNoToolsResponse() {
return "If you have completed the user's task, use the attempt_completion tool. If you require additional information from the user, use the ask_followup_question tool. Otherwise, if you have not completed the task and do not need additional information, then proceed with the next step of the task. (This is an automated message, so do not respond to it conversationally.)"
}
async sayAndCreateMissingParamError(toolName: ToolName, paramName: string, relPath?: string) {
await this.say(
"error",


@@ -270,7 +270,7 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
{!isCostAvailable && <ExportButton />}
</div>
{(shouldShowPromptCacheInfo || cacheReads !== undefined || cacheWrites !== undefined) && (
{shouldShowPromptCacheInfo && (cacheReads !== undefined || cacheWrites !== undefined) && (
<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
<span style={{ fontWeight: "bold" }}>Cache:</span>
<span style={{ display: "flex", alignItems: "center", gap: "3px" }}>
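The TaskHeader change regroups the render condition: previously the cache row appeared if any of the three values was truthy or defined, while the new grouping requires shouldShowPromptCacheInfo and at least one cache counter. A minimal sketch of the new predicate, with illustrative parameter names:

// Illustrative restatement of the new render condition (not part of this commit)
function shouldRenderCacheRow(
	shouldShowPromptCacheInfo: boolean,
	cacheReads?: number,
	cacheWrites?: number
): boolean {
	// The old grouping showed the row when ANY of the three was set;
	// the new grouping requires the prompt-cache flag AND at least one counter.
	return shouldShowPromptCacheInfo && (cacheReads !== undefined || cacheWrites !== undefined)
}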