From 4008a1a53e3c2c6609f0409c1d92998f4747db27 Mon Sep 17 00:00:00 2001
From: Vignesh Subbiah
Date: Tue, 28 Jan 2025 19:43:52 +0530
Subject: [PATCH] Rework the Unbound provider for consistency with the other
 providers

---
 src/api/providers/__tests__/unbound.test.ts | 236 ++++++++++++++++----
 src/api/providers/unbound.ts                | 146 +++++++++---
 2 files changed, 304 insertions(+), 78 deletions(-)

diff --git a/src/api/providers/__tests__/unbound.test.ts b/src/api/providers/__tests__/unbound.test.ts
index 721ba53..7d11e6d 100644
--- a/src/api/providers/__tests__/unbound.test.ts
+++ b/src/api/providers/__tests__/unbound.test.ts
@@ -1,64 +1,210 @@
 import { UnboundHandler } from "../unbound"
 import { ApiHandlerOptions } from "../../../shared/api"
-import fetchMock from "jest-fetch-mock"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
 
-fetchMock.enableMocks()
+// Mock OpenAI client
+const mockCreate = jest.fn()
+const mockWithResponse = jest.fn()
+
+jest.mock("openai", () => {
+	return {
+		__esModule: true,
+		default: jest.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: (...args: any[]) => {
+						const stream = {
+							[Symbol.asyncIterator]: async function* () {
+								yield {
+									choices: [
+										{
+											delta: { content: "Test response" },
+											index: 0,
+										},
+									],
+								}
+								yield {
+									choices: [
+										{
+											delta: {},
+											index: 0,
+										},
+									],
+								}
+							},
+						}
+
+						const result = mockCreate(...args)
+						if (args[0].stream) {
+							mockWithResponse.mockReturnValue(
+								Promise.resolve({
+									data: stream,
+									response: { headers: new Map() },
+								}),
+							)
+							result.withResponse = mockWithResponse
+						}
+						return result
+					},
+				},
+			},
+		})),
+	}
+})
 
 describe("UnboundHandler", () => {
-	const mockOptions: ApiHandlerOptions = {
-		unboundApiKey: "test-api-key",
-		apiModelId: "test-model-id",
-	}
+	let handler: UnboundHandler
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		fetchMock.resetMocks()
-	})
-
-	it("should initialize with options", () => {
-		const handler = new UnboundHandler(mockOptions)
-		expect(handler).toBeDefined()
-	})
-
-	it("should create a message successfully", async () => {
-		const handler = new UnboundHandler(mockOptions)
-		const mockResponse = {
-			choices: [{ message: { content: "Hello, world!" } }],
-			usage: { prompt_tokens: 5, completion_tokens: 7 },
+		mockOptions = {
+			apiModelId: "anthropic/claude-3-5-sonnet-20241022",
+			unboundApiKey: "test-api-key",
 		}
+		handler = new UnboundHandler(mockOptions)
+		mockCreate.mockClear()
+		mockWithResponse.mockClear()
 
-		fetchMock.mockResponseOnce(JSON.stringify(mockResponse))
-
-		const generator = handler.createMessage("system prompt", [])
-		const textResult = await generator.next()
-		const usageResult = await generator.next()
-
-		expect(textResult.value).toEqual({ type: "text", text: "Hello, world!" })
-		expect(usageResult.value).toEqual({
-			type: "usage",
-			inputTokens: 5,
-			outputTokens: 7,
+		// Default mock implementation for non-streaming responses
+		mockCreate.mockResolvedValue({
+			id: "test-completion",
+			choices: [
+				{
+					message: { role: "assistant", content: "Test response" },
+					finish_reason: "stop",
+					index: 0,
+				},
+			],
 		})
 	})
 
-	it("should handle API errors", async () => {
-		const handler = new UnboundHandler(mockOptions)
-		fetchMock.mockResponseOnce(JSON.stringify({ error: "API error" }), { status: 400 })
-
-		const generator = handler.createMessage("system prompt", [])
-		await expect(generator.next()).rejects.toThrow("Unbound Gateway completion error: API error")
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(UnboundHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
 	})
 
-	it("should handle network errors", async () => {
-		const handler = new UnboundHandler(mockOptions)
-		fetchMock.mockRejectOnce(new Error("Network error"))
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: "Hello!",
+			},
+		]
 
-		const generator = handler.createMessage("system prompt", [])
-		await expect(generator.next()).rejects.toThrow("Unbound Gateway completion error: Network error")
+		it("should handle streaming responses", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(1)
+			expect(chunks[0]).toEqual({
+				type: "text",
+				text: "Test response",
+			})
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "claude-3-5-sonnet-20241022",
+					messages: expect.any(Array),
+					stream: true,
+				}),
+				expect.objectContaining({
+					headers: {
+						"X-Unbound-Metadata": expect.stringContaining("roo-code"),
+					},
+				}),
+			)
+		})
+
+		it("should handle API errors", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				throw new Error("API Error")
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks = []
+
+			try {
+				for await (const chunk of stream) {
+					chunks.push(chunk)
+				}
+				fail("Expected error to be thrown")
+			} catch (error) {
+				expect(error).toBeInstanceOf(Error)
+				expect(error.message).toBe("API Error")
+			}
+		})
 	})
 
-	it("should return the correct model", () => {
-		const handler = new UnboundHandler(mockOptions)
-		const model = handler.getModel()
-		expect(model.id).toBe("gpt-4o")
+	describe("completePrompt", () => {
+		it("should complete prompt successfully", async () => {
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "claude-3-5-sonnet-20241022",
+					messages: [{ role: "user", content: "Test prompt" }],
+					temperature: 0,
+					max_tokens: 8192,
+				}),
+			)
+		})
+
+		it("should handle API errors", async () => {
+			mockCreate.mockRejectedValueOnce(new Error("API Error"))
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Unbound completion error: API Error")
+		})
+
+		it("should handle empty response", async () => {
+			mockCreate.mockResolvedValueOnce({
+				choices: [{ message: { content: "" } }],
+			})
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})
+
+		it("should not set max_tokens for non-Anthropic models", async () => {
+			mockCreate.mockClear()
+
+			const nonAnthropicOptions = {
+				apiModelId: "openai/gpt-4o",
+				unboundApiKey: "test-key",
+			}
+			const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions)
+
+			await nonAnthropicHandler.completePrompt("Test prompt")
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "gpt-4o",
+					messages: [{ role: "user", content: "Test prompt" }],
+					temperature: 0,
+				}),
+			)
+			expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
+		})
+	})
+
+	describe("getModel", () => {
+		it("should return model info", () => {
+			const modelInfo = handler.getModel()
+			expect(modelInfo.id).toBe(mockOptions.apiModelId)
+			expect(modelInfo.info).toBeDefined()
+		})
+
+		it("should return default model when invalid model provided", () => {
+			const handlerWithInvalidModel = new UnboundHandler({
+				...mockOptions,
+				apiModelId: "invalid/model",
+			})
+			const modelInfo = handlerWithInvalidModel.getModel()
+			expect(modelInfo.id).toBe("openai/gpt-4o") // Default model
+			expect(modelInfo.info).toBeDefined()
+		})
 	})
 })
diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts
index 842ab26..1992d71 100644
--- a/src/api/providers/unbound.ts
+++ b/src/api/providers/unbound.ts
@@ -1,50 +1,108 @@
-import { ApiHandlerOptions, unboundModels, UnboundModelId, unboundDefaultModelId, ModelInfo } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler } from "../index"
+import OpenAI from "openai"
+import { ApiHandler, SingleCompletionHandler } from "../"
+import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
-export class UnboundHandler implements ApiHandler {
-	private unboundBaseUrl: string = "https://api.getunbound.ai/v1"
+export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
+	private client: OpenAI
 
 	constructor(options: ApiHandlerOptions) {
 		this.options = options
+		this.client = new OpenAI({
+			baseURL: "https://api.getunbound.ai/v1",
+			apiKey: this.options.unboundApiKey,
+		})
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		try {
-			const response = await fetch(`${this.unboundBaseUrl}/chat/completions`, {
-				method: "POST",
-				headers: {
-					Authorization: `Bearer ${this.options.unboundApiKey}`,
-					"Content-Type": "application/json",
-				},
-				body: JSON.stringify({
-					model: this.getModel().id.split("/")[1],
-					messages: [{ role: "system", content: systemPrompt }, ...messages],
-				}),
+		// Convert Anthropic messages to OpenAI format
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
+
+		// This is specific to Claude models (some models may support prompt caching automatically without it).
+		if (this.getModel().id.startsWith("anthropic/claude-3")) {
+			openAiMessages[0] = {
+				role: "system",
+				content: [
+					{
+						type: "text",
+						text: systemPrompt,
+						// @ts-ignore-next-line
+						cache_control: { type: "ephemeral" },
+					},
+				],
+			}
+
+			// Add cache_control to the last two user messages.
+			// (Note: this works because we only ever add one user message at a time;
+			// if we added several, we would need to mark the user message that precedes
+			// the last assistant message instead.)
+			const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
+			lastTwoUserMessages.forEach((msg) => {
+				if (typeof msg.content === "string") {
+					msg.content = [{ type: "text", text: msg.content }]
+				}
+				if (Array.isArray(msg.content)) {
+					// NOTE: this is fine since env details will always be added at the end;
+					// if they weren't, and the user sent an image_url-type message, this would
+					// pull a text part from before the image and move it to the end.
+					let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
+
+					if (!lastTextPart) {
+						lastTextPart = { type: "text", text: "..." }
+						msg.content.push(lastTextPart)
+					}
+					// @ts-ignore-next-line
+					lastTextPart["cache_control"] = { type: "ephemeral" }
+				}
 			})
+		}
 
-			const data = await response.json()
+		// Required by Anthropic.
+		// Other providers default to the maximum tokens allowed.
+		let maxTokens: number | undefined
 
-			if (!response.ok) {
-				throw new Error(data.error.message)
-			}
+		if (this.getModel().id.startsWith("anthropic/")) {
+			maxTokens = 8_192
+		}
 
-			yield {
-				type: "text",
-				text: data.choices[0]?.message?.content || "",
+		const { data: completion, response } = await this.client.chat.completions
+			.create(
+				{
+					model: this.getModel().id.split("/")[1],
+					max_tokens: maxTokens,
+					temperature: 0,
+					messages: openAiMessages,
+					stream: true,
+				},
+				{
+					headers: {
+						"X-Unbound-Metadata": JSON.stringify({
+							labels: [
+								{
+									key: "app",
+									value: "roo-code",
+								},
+							],
+						}),
+					},
+				},
+			)
+			.withResponse()
+
+		for await (const chunk of completion) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
 			}
-			yield {
-				type: "usage",
-				inputTokens: data.usage?.prompt_tokens || 0,
-				outputTokens: data.usage?.completion_tokens || 0,
-			}
-		} catch (error) {
-			if (error instanceof Error) {
-				throw new Error(`Unbound Gateway completion error:\n ${error.message}`)
-			}
-			throw error
 		}
 	}
 
@@ -59,4 +117,26 @@ export class UnboundHandler implements ApiHandler {
 			info: unboundModels[unboundDefaultModelId],
 		}
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: this.getModel().id.split("/")[1],
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+			}
+
+			if (this.getModel().id.startsWith("anthropic/")) {
+				requestOptions.max_tokens = 8192
+			}
+
+			const response = await this.client.chat.completions.create(requestOptions)
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Unbound completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }
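-- 
Usage sketch (outside the patch, for illustration only): how a caller might
consume the reworked handler. It assumes the options shape exercised in the
tests above; the API key is a placeholder and the import path depends on
where the caller lives.

	import { UnboundHandler } from "./src/api/providers/unbound"

	async function demo() {
		const handler = new UnboundHandler({
			apiModelId: "anthropic/claude-3-5-sonnet-20241022",
			unboundApiKey: "YOUR_UNBOUND_API_KEY", // placeholder, not a real key
		})

		// Streaming path: createMessage yields text chunks as they arrive.
		for await (const chunk of handler.createMessage("You are a helpful assistant.", [
			{ role: "user", content: "Hello!" },
		])) {
			if (chunk.type === "text") {
				process.stdout.write(chunk.text)
			}
		}

		// Single-shot path: completePrompt resolves to the full completion text.
		console.log(await handler.completePrompt("Hello!"))
	}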