diff --git a/src/api/index.ts b/src/api/index.ts index 641c50d..b3927b4 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -14,6 +14,7 @@ import { DeepSeekHandler } from "./providers/deepseek" import { MistralHandler } from "./providers/mistral" import { VsCodeLmHandler } from "./providers/vscode-lm" import { ApiStream } from "./transform/stream" +import { UnboundHandler } from "./providers/unbound" export interface SingleCompletionHandler { completePrompt(prompt: string): Promise @@ -53,6 +54,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler { return new VsCodeLmHandler(options) case "mistral": return new MistralHandler(options) + case "unbound": + return new UnboundHandler(options) default: return new AnthropicHandler(options) } diff --git a/src/api/providers/__tests__/unbound.test.ts b/src/api/providers/__tests__/unbound.test.ts new file mode 100644 index 0000000..7d11e6d --- /dev/null +++ b/src/api/providers/__tests__/unbound.test.ts @@ -0,0 +1,210 @@ +import { UnboundHandler } from "../unbound" +import { ApiHandlerOptions } from "../../../shared/api" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" + +// Mock OpenAI client +const mockCreate = jest.fn() +const mockWithResponse = jest.fn() + +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: (...args: any[]) => { + const stream = { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + } + }, + } + + const result = mockCreate(...args) + if (args[0].stream) { + mockWithResponse.mockReturnValue( + Promise.resolve({ + data: stream, + response: { headers: new Map() }, + }), + ) + result.withResponse = mockWithResponse + } + return result + }, + }, + }, + })), + } +}) + +describe("UnboundHandler", () => { + let handler: UnboundHandler + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + apiModelId: "anthropic/claude-3-5-sonnet-20241022", + unboundApiKey: "test-api-key", + } + handler = new UnboundHandler(mockOptions) + mockCreate.mockClear() + mockWithResponse.mockClear() + + // Default mock implementation for non-streaming responses + mockCreate.mockResolvedValue({ + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response" }, + finish_reason: "stop", + index: 0, + }, + ], + }) + }) + + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(UnboundHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) + }) + + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello!", + }, + ] + + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBe(1) + expect(chunks[0]).toEqual({ + type: "text", + text: "Test response", + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "claude-3-5-sonnet-20241022", + messages: expect.any(Array), + stream: true, + }), + expect.objectContaining({ + headers: { + "X-Unbound-Metadata": expect.stringContaining("roo-code"), + }, + }), + ) + }) + + it("should handle API errors", async () => { + mockCreate.mockImplementationOnce(() => { + throw new Error("API Error") + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks = [] + + try { + for await (const chunk of stream) { + chunks.push(chunk) + } + fail("Expected error to be thrown") + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect(error.message).toBe("API Error") + } + }) + }) + + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "claude-3-5-sonnet-20241022", + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + max_tokens: 8192, + }), + ) + }) + + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Unbound completion error: API Error") + }) + + it("should handle empty response", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "" } }], + }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + + it("should not set max_tokens for non-Anthropic models", async () => { + mockCreate.mockClear() + + const nonAnthropicOptions = { + apiModelId: "openai/gpt-4o", + unboundApiKey: "test-key", + } + const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions) + + await nonAnthropicHandler.completePrompt("Test prompt") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "gpt-4o", + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + }), + ) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens") + }) + }) + + describe("getModel", () => { + it("should return model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe(mockOptions.apiModelId) + expect(modelInfo.info).toBeDefined() + }) + + it("should return default model when invalid model provided", () => { + const handlerWithInvalidModel = new UnboundHandler({ + ...mockOptions, + apiModelId: "invalid/model", + }) + const modelInfo = handlerWithInvalidModel.getModel() + expect(modelInfo.id).toBe("openai/gpt-4o") // Default model + expect(modelInfo.info).toBeDefined() + }) + }) +}) diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts new file mode 100644 index 0000000..23e419c --- /dev/null +++ b/src/api/providers/unbound.ts @@ -0,0 +1,151 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" +import { ApiHandler, SingleCompletionHandler } from "../" +import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiStream } from "../transform/stream" + +export class UnboundHandler implements ApiHandler, SingleCompletionHandler { + private options: ApiHandlerOptions + private client: OpenAI + + constructor(options: ApiHandlerOptions) { + this.options = options + this.client = new OpenAI({ + baseURL: "https://api.getunbound.ai/v1", + apiKey: this.options.unboundApiKey, + }) + } + + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + // Convert Anthropic messages to OpenAI format + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ + { role: "system", content: systemPrompt }, + ...convertToOpenAiMessages(messages), + ] + + // this is specifically for claude models (some models may 'support prompt caching' automatically without this) + if (this.getModel().id.startsWith("anthropic/claude-3")) { + openAiMessages[0] = { + role: "system", + content: [ + { + type: "text", + text: systemPrompt, + // @ts-ignore-next-line + cache_control: { type: "ephemeral" }, + }, + ], + } + + // Add cache_control to the last two user messages + // (note: this works because we only ever add one user message at a time, + // but if we added multiple we'd need to mark the user message before the last assistant message) + const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2) + lastTwoUserMessages.forEach((msg) => { + if (typeof msg.content === "string") { + msg.content = [{ type: "text", text: msg.content }] + } + if (Array.isArray(msg.content)) { + // NOTE: this is fine since env details will always be added at the end. + // but if it weren't there, and the user added a image_url type message, + // it would pop a text part before it and then move it after to the end. + let lastTextPart = msg.content.filter((part) => part.type === "text").pop() + + if (!lastTextPart) { + lastTextPart = { type: "text", text: "..." } + msg.content.push(lastTextPart) + } + // @ts-ignore-next-line + lastTextPart["cache_control"] = { type: "ephemeral" } + } + }) + } + + // Required by Anthropic + // Other providers default to max tokens allowed. + let maxTokens: number | undefined + + if (this.getModel().id.startsWith("anthropic/")) { + maxTokens = 8_192 + } + + const { data: completion, response } = await this.client.chat.completions + .create( + { + model: this.getModel().id.split("/")[1], + max_tokens: maxTokens, + temperature: 0, + messages: openAiMessages, + stream: true, + }, + { + headers: { + "X-Unbound-Metadata": JSON.stringify({ + labels: [ + { + key: "app", + value: "roo-code", + }, + ], + }), + }, + }, + ) + .withResponse() + + for await (const chunk of completion) { + const delta = chunk.choices[0]?.delta + const usage = chunk.usage + + if (delta?.content) { + yield { + type: "text", + text: delta.content, + } + } + + if (usage) { + yield { + type: "usage", + inputTokens: usage?.prompt_tokens || 0, + outputTokens: usage?.completion_tokens || 0, + } + } + } + } + + getModel(): { id: UnboundModelId; info: ModelInfo } { + const modelId = this.options.apiModelId + if (modelId && modelId in unboundModels) { + const id = modelId as UnboundModelId + return { id, info: unboundModels[id] } + } + return { + id: unboundDefaultModelId, + info: unboundModels[unboundDefaultModelId], + } + } + + async completePrompt(prompt: string): Promise { + try { + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: this.getModel().id.split("/")[1], + messages: [{ role: "user", content: prompt }], + temperature: 0, + } + + if (this.getModel().id.startsWith("anthropic/")) { + requestOptions.max_tokens = 8192 + } + + const response = await this.client.chat.completions.create(requestOptions) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Unbound completion error: ${error.message}`) + } + throw error + } + } +} diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 1c15369..ba6a06e 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -63,6 +63,7 @@ type SecretKey = | "openAiNativeApiKey" | "deepSeekApiKey" | "mistralApiKey" + | "unboundApiKey" type GlobalStateKey = | "apiProvider" | "apiModelId" @@ -122,6 +123,7 @@ type GlobalStateKey = | "experiments" // Map of experiment IDs to their enabled state | "autoApprovalEnabled" | "customModes" // Array of custom modes + | "unboundModelId" export const GlobalFileNames = { apiConversationHistory: "api_conversation_history.json", @@ -1397,6 +1399,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { openRouterUseMiddleOutTransform, vsCodeLmModelSelector, mistralApiKey, + unboundApiKey, + unboundModelId, } = apiConfiguration await this.updateGlobalState("apiProvider", apiProvider) await this.updateGlobalState("apiModelId", apiModelId) @@ -1435,6 +1439,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform) await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector) await this.storeSecret("mistralApiKey", mistralApiKey) + await this.storeSecret("unboundApiKey", unboundApiKey) + await this.updateGlobalState("unboundModelId", unboundModelId) if (this.cline) { this.cline.api = buildApiHandler(apiConfiguration) } @@ -2102,6 +2108,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { autoApprovalEnabled, customModes, experiments, + unboundApiKey, + unboundModelId, ] = await Promise.all([ this.getGlobalState("apiProvider") as Promise, this.getGlobalState("apiModelId") as Promise, @@ -2172,6 +2180,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getGlobalState("autoApprovalEnabled") as Promise, this.customModesManager.getCustomModes(), this.getGlobalState("experiments") as Promise | undefined>, + this.getSecret("unboundApiKey") as Promise, + this.getGlobalState("unboundModelId") as Promise, ]) let apiProvider: ApiProvider @@ -2227,6 +2237,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { openRouterBaseUrl, openRouterUseMiddleOutTransform, vsCodeLmModelSelector, + unboundApiKey, + unboundModelId, }, lastShownAnnouncementId, customInstructions, @@ -2376,6 +2388,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { "openAiNativeApiKey", "deepSeekApiKey", "mistralApiKey", + "unboundApiKey", ] for (const key of secretKeys) { await this.storeSecret(key, undefined) diff --git a/src/shared/api.ts b/src/shared/api.ts index 950b94b..e5bcda4 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -14,6 +14,7 @@ export type ApiProvider = | "deepseek" | "vscode-lm" | "mistral" + | "unbound" export interface ApiHandlerOptions { apiModelId?: string @@ -57,6 +58,8 @@ export interface ApiHandlerOptions { deepSeekBaseUrl?: string deepSeekApiKey?: string includeMaxTokens?: boolean + unboundApiKey?: string + unboundModelId?: string } export type ApiConfiguration = ApiHandlerOptions & { @@ -593,3 +596,14 @@ export const mistralModels = { outputPrice: 0.9, }, } as const satisfies Record + +// Unbound Security +export type UnboundModelId = keyof typeof unboundModels +export const unboundDefaultModelId = "openai/gpt-4o" +export const unboundModels = { + "anthropic/claude-3-5-sonnet-20241022": anthropicModels["claude-3-5-sonnet-20241022"], + "openai/gpt-4o": openAiNativeModels["gpt-4o"], + "deepseek/deepseek-chat": deepSeekModels["deepseek-chat"], + "deepseek/deepseek-reasoner": deepSeekModels["deepseek-reasoner"], + "mistral/codestral-latest": mistralModels["codestral-latest"], +} as const satisfies Record diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 1be00c7..1199914 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -26,6 +26,8 @@ import { openRouterDefaultModelInfo, vertexDefaultModelId, vertexModels, + unboundDefaultModelId, + unboundModels, } from "../../../../src/shared/api" import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage" import { useExtensionState } from "../../context/ExtensionStateContext" @@ -147,6 +149,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = { value: "mistral", label: "Mistral" }, { value: "lmstudio", label: "LM Studio" }, { value: "ollama", label: "Ollama" }, + { value: "unbound", label: "Unbound" }, ]} /> @@ -1283,6 +1286,35 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = )} + {selectedProvider === "unbound" && ( +
+ + Unbound API Key + + {!apiConfiguration?.unboundApiKey && ( + + Get Unbound API Key + + )} +

+ This key is stored locally and only used to make API requests from this extension. +

+
+ )} + {apiErrorMessage && (