From 4027e1c10c92dd3362a2afee1fd00db177f5053a Mon Sep 17 00:00:00 2001 From: Matt Rubens Date: Mon, 13 Jan 2025 16:16:58 -0500 Subject: [PATCH] Add non-streaming completePrompt to all providers --- src/api/providers/__tests__/anthropic.test.ts | 73 +++- src/api/providers/__tests__/bedrock.test.ts | 102 +++++ src/api/providers/__tests__/gemini.test.ts | 60 ++- src/api/providers/__tests__/glama.test.ts | 226 ++++++++++ src/api/providers/__tests__/lmstudio.test.ts | 212 ++++----- src/api/providers/__tests__/ollama.test.ts | 208 ++++----- .../providers/__tests__/openai-native.test.ts | 407 +++++++++--------- src/api/providers/__tests__/openai.test.ts | 26 ++ src/api/providers/__tests__/vertex.test.ts | 80 +++- src/api/providers/anthropic.ts | 27 +- src/api/providers/bedrock.ts | 73 +++- src/api/providers/gemini.ts | 26 +- src/api/providers/glama.ts | 26 +- src/api/providers/lmstudio.ts | 20 +- src/api/providers/ollama.ts | 21 +- src/api/providers/openai-native.ts | 37 +- src/api/providers/openai.ts | 22 +- src/api/providers/vertex.ts | 27 +- 18 files changed, 1235 insertions(+), 438 deletions(-) create mode 100644 src/api/providers/__tests__/glama.test.ts diff --git a/src/api/providers/__tests__/anthropic.test.ts b/src/api/providers/__tests__/anthropic.test.ts index f730f78..d0357d7 100644 --- a/src/api/providers/__tests__/anthropic.test.ts +++ b/src/api/providers/__tests__/anthropic.test.ts @@ -46,7 +46,42 @@ jest.mock('@anthropic-ai/sdk', () => { } }, messages: { - create: mockCreate + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: 'test-completion', + content: [ + { type: 'text', text: 'Test response' } + ], + role: 'assistant', + model: options.model, + usage: { + input_tokens: 10, + output_tokens: 5 + } + } + } + return { + async *[Symbol.asyncIterator]() { + yield { + type: 'message_start', + message: { + usage: { + input_tokens: 10, + output_tokens: 5 + } + } + } + yield { + type: 'content_block_start', + content_block: { + type: 'text', + text: 'Test response' + } + } + } + } + }) } })) }; @@ -144,6 +179,42 @@ describe('AnthropicHandler', () => { }); }); + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.apiModelId, + messages: [{ role: 'user', content: 'Test prompt' }], + max_tokens: 8192, + temperature: 0, + stream: false + }); + }); + + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Anthropic completion error: API Error'); + }); + + it('should handle non-text content', async () => { + mockCreate.mockImplementationOnce(async () => ({ + content: [{ type: 'image' }] + })); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + + it('should handle empty response', async () => { + mockCreate.mockImplementationOnce(async () => ({ + content: [{ type: 'text', text: '' }] + })); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + }); + describe('getModel', () => { it('should return default model if no model ID is provided', () => { const handlerWithoutModel = new AnthropicHandler({ diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index 36cccc1..e8e3f3a 100644 --- 
a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -119,6 +119,108 @@ describe('AwsBedrockHandler', () => { }); }); + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const mockResponse = { + output: new TextEncoder().encode(JSON.stringify({ + content: 'Test response' + })) + }; + + const mockSend = jest.fn().mockResolvedValue(mockResponse); + handler['client'] = { + send: mockSend + } as unknown as BedrockRuntimeClient; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({ + input: expect.objectContaining({ + modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'user', + content: [{ text: 'Test prompt' }] + }) + ]), + inferenceConfig: expect.objectContaining({ + maxTokens: 5000, + temperature: 0.3, + topP: 0.1 + }) + }) + })); + }); + + it('should handle API errors', async () => { + const mockError = new Error('AWS Bedrock error'); + const mockSend = jest.fn().mockRejectedValue(mockError); + handler['client'] = { + send: mockSend + } as unknown as BedrockRuntimeClient; + + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Bedrock completion error: AWS Bedrock error'); + }); + + it('should handle invalid response format', async () => { + const mockResponse = { + output: new TextEncoder().encode('invalid json') + }; + + const mockSend = jest.fn().mockResolvedValue(mockResponse); + handler['client'] = { + send: mockSend + } as unknown as BedrockRuntimeClient; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + + it('should handle empty response', async () => { + const mockResponse = { + output: new TextEncoder().encode(JSON.stringify({})) + }; + + const mockSend = jest.fn().mockResolvedValue(mockResponse); + handler['client'] = { + send: mockSend + } as unknown as BedrockRuntimeClient; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + + it('should handle cross-region inference', async () => { + handler = new AwsBedrockHandler({ + apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', + awsAccessKey: 'test-access-key', + awsSecretKey: 'test-secret-key', + awsRegion: 'us-east-1', + awsUseCrossRegionInference: true + }); + + const mockResponse = { + output: new TextEncoder().encode(JSON.stringify({ + content: 'Test response' + })) + }; + + const mockSend = jest.fn().mockResolvedValue(mockResponse); + handler['client'] = { + send: mockSend + } as unknown as BedrockRuntimeClient; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({ + input: expect.objectContaining({ + modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' + }) + })); + }); + }); + describe('getModel', () => { it('should return correct model info in test environment', () => { const modelInfo = handler.getModel(); diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index b979714..a59028e 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -6,7 +6,12 @@ import { GoogleGenerativeAI } from '@google/generative-ai'; jest.mock('@google/generative-ai', () => ({ GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ 
getGenerativeModel: jest.fn().mockReturnValue({ - generateContentStream: jest.fn() + generateContentStream: jest.fn(), + generateContent: jest.fn().mockResolvedValue({ + response: { + text: () => 'Test response' + } + }) }) })) })); @@ -133,6 +138,59 @@ describe('GeminiHandler', () => { }); }); + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const mockGenerateContent = jest.fn().mockResolvedValue({ + response: { + text: () => 'Test response' + } + }); + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContent: mockGenerateContent + }); + (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockGetGenerativeModel).toHaveBeenCalledWith({ + model: 'gemini-2.0-flash-thinking-exp-1219' + }); + expect(mockGenerateContent).toHaveBeenCalledWith({ + contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }], + generationConfig: { + temperature: 0 + } + }); + }); + + it('should handle API errors', async () => { + const mockError = new Error('Gemini API error'); + const mockGenerateContent = jest.fn().mockRejectedValue(mockError); + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContent: mockGenerateContent + }); + (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Gemini completion error: Gemini API error'); + }); + + it('should handle empty response', async () => { + const mockGenerateContent = jest.fn().mockResolvedValue({ + response: { + text: () => '' + } + }); + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContent: mockGenerateContent + }); + (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + }); + describe('getModel', () => { it('should return correct model info', () => { const modelInfo = handler.getModel(); diff --git a/src/api/providers/__tests__/glama.test.ts b/src/api/providers/__tests__/glama.test.ts new file mode 100644 index 0000000..e67b80e --- /dev/null +++ b/src/api/providers/__tests__/glama.test.ts @@ -0,0 +1,226 @@ +import { GlamaHandler } from '../glama'; +import { ApiHandlerOptions } from '../../../shared/api'; +import OpenAI from 'openai'; +import { Anthropic } from '@anthropic-ai/sdk'; +import axios from 'axios'; + +// Mock OpenAI client +const mockCreate = jest.fn(); +const mockWithResponse = jest.fn(); + +jest.mock('openai', () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: (...args: any[]) => { + const stream = { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [{ + delta: { content: 'Test response' }, + index: 0 + }], + usage: null + }; + yield { + choices: [{ + delta: {}, + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + }; + + const result = mockCreate(...args); + if (args[0].stream) { + mockWithResponse.mockReturnValue(Promise.resolve({ + data: stream, + response: { + headers: { + get: (name: string) => name === 'x-completion-request-id' ? 
'test-request-id' : null + } + } + })); + result.withResponse = mockWithResponse; + } + return result; + } + } + } + })) + }; +}); + +describe('GlamaHandler', () => { + let handler: GlamaHandler; + let mockOptions: ApiHandlerOptions; + + beforeEach(() => { + mockOptions = { + apiModelId: 'anthropic/claude-3-5-sonnet', + glamaModelId: 'anthropic/claude-3-5-sonnet', + glamaApiKey: 'test-api-key' + }; + handler = new GlamaHandler(mockOptions); + mockCreate.mockClear(); + mockWithResponse.mockClear(); + + // Default mock implementation for non-streaming responses + mockCreate.mockResolvedValue({ + id: 'test-completion', + choices: [{ + message: { role: 'assistant', content: 'Test response' }, + finish_reason: 'stop', + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }); + }); + + describe('constructor', () => { + it('should initialize with provided options', () => { + expect(handler).toBeInstanceOf(GlamaHandler); + expect(handler.getModel().id).toBe(mockOptions.apiModelId); + }); + }); + + describe('createMessage', () => { + const systemPrompt = 'You are a helpful assistant.'; + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: 'user', + content: 'Hello!' + } + ]; + + it('should handle streaming responses', async () => { + // Mock axios for token usage request + const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({ + data: { + tokenUsage: { + promptTokens: 10, + completionTokens: 5, + cacheCreationInputTokens: 0, + cacheReadInputTokens: 0 + }, + totalCostUsd: "0.00" + } + }); + + const stream = handler.createMessage(systemPrompt, messages); + const chunks: any[] = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + + expect(chunks.length).toBe(2); // Text chunk and usage chunk + expect(chunks[0]).toEqual({ + type: 'text', + text: 'Test response' + }); + expect(chunks[1]).toEqual({ + type: 'usage', + inputTokens: 10, + outputTokens: 5, + cacheWriteTokens: 0, + cacheReadTokens: 0, + totalCost: 0 + }); + + mockAxios.mockRestore(); + }); + + it('should handle API errors', async () => { + mockCreate.mockImplementationOnce(() => { + throw new Error('API Error'); + }); + + const stream = handler.createMessage(systemPrompt, messages); + const chunks = []; + + try { + for await (const chunk of stream) { + chunks.push(chunk); + } + fail('Expected error to be thrown'); + } catch (error) { + expect(error).toBeInstanceOf(Error); + expect(error.message).toBe('API Error'); + } + }); + }); + + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ + model: mockOptions.apiModelId, + messages: [{ role: 'user', content: 'Test prompt' }], + temperature: 0, + max_tokens: 8192 + })); + }); + + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Glama completion error: API Error'); + }); + + it('should handle empty response', async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: '' } }] + }); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + + it('should not set max_tokens for non-Anthropic models', async () => { + // Reset mock to clear any previous calls + mockCreate.mockClear(); + + const nonAnthropicOptions = { + 
apiModelId: 'openai/gpt-4', + glamaModelId: 'openai/gpt-4', + glamaApiKey: 'test-key', + glamaModelInfo: { + maxTokens: 4096, + contextWindow: 8192, + supportsImages: true, + supportsPromptCache: false + } + }; + const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions); + + await nonAnthropicHandler.completePrompt('Test prompt'); + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ + model: 'openai/gpt-4', + messages: [{ role: 'user', content: 'Test prompt' }], + temperature: 0 + })); + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens'); + }); + }); + + describe('getModel', () => { + it('should return model info', () => { + const modelInfo = handler.getModel(); + expect(modelInfo.id).toBe(mockOptions.apiModelId); + expect(modelInfo.info).toBeDefined(); + expect(modelInfo.info.maxTokens).toBe(8192); + expect(modelInfo.info.contextWindow).toBe(200_000); + }); + }); +}); \ No newline at end of file diff --git a/src/api/providers/__tests__/lmstudio.test.ts b/src/api/providers/__tests__/lmstudio.test.ts index 9e24053..6b84796 100644 --- a/src/api/providers/__tests__/lmstudio.test.ts +++ b/src/api/providers/__tests__/lmstudio.test.ts @@ -1,148 +1,160 @@ import { LmStudioHandler } from '../lmstudio'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { ApiHandlerOptions } from '../../../shared/api'; import OpenAI from 'openai'; +import { Anthropic } from '@anthropic-ai/sdk'; -// Mock OpenAI SDK -jest.mock('openai', () => ({ - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: jest.fn() +// Mock OpenAI client +const mockCreate = jest.fn(); +jest.mock('openai', () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: 'test-completion', + choices: [{ + message: { role: 'assistant', content: 'Test response' }, + finish_reason: 'stop', + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [{ + delta: { content: 'Test response' }, + index: 0 + }], + usage: null + }; + yield { + choices: [{ + delta: {}, + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + }; + }) + } } - } - })) -})); + })) + }; +}); describe('LmStudioHandler', () => { let handler: LmStudioHandler; + let mockOptions: ApiHandlerOptions; beforeEach(() => { - handler = new LmStudioHandler({ - lmStudioModelId: 'mistral-7b', - lmStudioBaseUrl: 'http://localhost:1234' - }); + mockOptions = { + apiModelId: 'local-model', + lmStudioModelId: 'local-model', + lmStudioBaseUrl: 'http://localhost:1234/v1' + }; + handler = new LmStudioHandler(mockOptions); + mockCreate.mockClear(); }); describe('constructor', () => { - it('should initialize with provided config', () => { - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: 'http://localhost:1234/v1', - apiKey: 'noop' - }); + it('should initialize with provided options', () => { + expect(handler).toBeInstanceOf(LmStudioHandler); + expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId); }); it('should use default base URL if not provided', () => { - const defaultHandler = new LmStudioHandler({ - lmStudioModelId: 'mistral-7b' - }); - - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: 'http://localhost:1234/v1', - apiKey: 'noop' + const handlerWithoutUrl = 
new LmStudioHandler({ + apiModelId: 'local-model', + lmStudioModelId: 'local-model' }); + expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler); }); }); describe('createMessage', () => { - const mockMessages: Anthropic.Messages.MessageParam[] = [ + const systemPrompt = 'You are a helpful assistant.'; + const messages: Anthropic.Messages.MessageParam[] = [ { role: 'user', - content: 'Hello' - }, - { - role: 'assistant', - content: 'Hi there!' + content: 'Hello!' } ]; - const systemPrompt = 'You are a helpful assistant'; - - it('should handle streaming responses correctly', async () => { - const mockStream = [ - { - choices: [{ - delta: { content: 'Hello' } - }] - }, - { - choices: [{ - delta: { content: ' world!' } - }] - } - ]; - - // Setup async iterator for mock stream - const asyncIterator = { - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - } - }; - - const mockCreate = jest.fn().mockResolvedValue(asyncIterator); - (handler['client'].chat.completions as any).create = mockCreate; - - const stream = handler.createMessage(systemPrompt, mockMessages); - const chunks = []; - + it('should handle streaming responses', async () => { + const stream = handler.createMessage(systemPrompt, messages); + const chunks: any[] = []; for await (const chunk of stream) { chunks.push(chunk); } - expect(chunks.length).toBe(2); - expect(chunks[0]).toEqual({ - type: 'text', - text: 'Hello' - }); - expect(chunks[1]).toEqual({ - type: 'text', - text: ' world!' - }); - - expect(mockCreate).toHaveBeenCalledWith({ - model: 'mistral-7b', - messages: expect.arrayContaining([ - { - role: 'system', - content: systemPrompt - } - ]), - temperature: 0, - stream: true - }); + expect(chunks.length).toBeGreaterThan(0); + const textChunks = chunks.filter(chunk => chunk.type === 'text'); + expect(textChunks).toHaveLength(1); + expect(textChunks[0].text).toBe('Test response'); }); - it('should handle API errors with custom message', async () => { - const mockError = new Error('LM Studio API error'); - const mockCreate = jest.fn().mockRejectedValue(mockError); - (handler['client'].chat.completions as any).create = mockCreate; + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); - const stream = handler.createMessage(systemPrompt, mockMessages); + const stream = handler.createMessage(systemPrompt, messages); await expect(async () => { for await (const chunk of stream) { - // Should throw before yielding any chunks + // Should not reach here } }).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong'); }); }); + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.lmStudioModelId, + messages: [{ role: 'user', content: 'Test prompt' }], + temperature: 0, + stream: false + }); + }); + + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong'); + }); + + it('should handle empty response', async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: '' } }] + }); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + }); + describe('getModel', () => 
{ - it('should return model info with sane defaults', () => { + it('should return model info', () => { const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe('mistral-7b'); + expect(modelInfo.id).toBe(mockOptions.lmStudioModelId); expect(modelInfo.info).toBeDefined(); expect(modelInfo.info.maxTokens).toBe(-1); expect(modelInfo.info.contextWindow).toBe(128_000); }); - - it('should return empty string as model ID if not provided', () => { - const noModelHandler = new LmStudioHandler({}); - const modelInfo = noModelHandler.getModel(); - expect(modelInfo.id).toBe(''); - expect(modelInfo.info).toBeDefined(); - }); }); }); \ No newline at end of file diff --git a/src/api/providers/__tests__/ollama.test.ts b/src/api/providers/__tests__/ollama.test.ts index 3d74e88..fc4c9f5 100644 --- a/src/api/providers/__tests__/ollama.test.ts +++ b/src/api/providers/__tests__/ollama.test.ts @@ -1,148 +1,160 @@ import { OllamaHandler } from '../ollama'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { ApiHandlerOptions } from '../../../shared/api'; import OpenAI from 'openai'; +import { Anthropic } from '@anthropic-ai/sdk'; -// Mock OpenAI SDK -jest.mock('openai', () => ({ - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: jest.fn() +// Mock OpenAI client +const mockCreate = jest.fn(); +jest.mock('openai', () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: 'test-completion', + choices: [{ + message: { role: 'assistant', content: 'Test response' }, + finish_reason: 'stop', + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [{ + delta: { content: 'Test response' }, + index: 0 + }], + usage: null + }; + yield { + choices: [{ + delta: {}, + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + }; + }) + } } - } - })) -})); + })) + }; +}); describe('OllamaHandler', () => { let handler: OllamaHandler; + let mockOptions: ApiHandlerOptions; beforeEach(() => { - handler = new OllamaHandler({ + mockOptions = { + apiModelId: 'llama2', ollamaModelId: 'llama2', - ollamaBaseUrl: 'http://localhost:11434' - }); + ollamaBaseUrl: 'http://localhost:11434/v1' + }; + handler = new OllamaHandler(mockOptions); + mockCreate.mockClear(); }); describe('constructor', () => { - it('should initialize with provided config', () => { - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: 'http://localhost:11434/v1', - apiKey: 'ollama' - }); + it('should initialize with provided options', () => { + expect(handler).toBeInstanceOf(OllamaHandler); + expect(handler.getModel().id).toBe(mockOptions.ollamaModelId); }); it('should use default base URL if not provided', () => { - const defaultHandler = new OllamaHandler({ + const handlerWithoutUrl = new OllamaHandler({ + apiModelId: 'llama2', ollamaModelId: 'llama2' }); - - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: 'http://localhost:11434/v1', - apiKey: 'ollama' - }); + expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler); }); }); describe('createMessage', () => { - const mockMessages: Anthropic.Messages.MessageParam[] = [ + const systemPrompt = 'You are a helpful assistant.'; + const messages: Anthropic.Messages.MessageParam[] = [ { role: 'user', - content: 'Hello' - }, - { - 
role: 'assistant', - content: 'Hi there!' + content: 'Hello!' } ]; - const systemPrompt = 'You are a helpful assistant'; - - it('should handle streaming responses correctly', async () => { - const mockStream = [ - { - choices: [{ - delta: { content: 'Hello' } - }] - }, - { - choices: [{ - delta: { content: ' world!' } - }] - } - ]; - - // Setup async iterator for mock stream - const asyncIterator = { - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - } - }; - - const mockCreate = jest.fn().mockResolvedValue(asyncIterator); - (handler['client'].chat.completions as any).create = mockCreate; - - const stream = handler.createMessage(systemPrompt, mockMessages); - const chunks = []; - + it('should handle streaming responses', async () => { + const stream = handler.createMessage(systemPrompt, messages); + const chunks: any[] = []; for await (const chunk of stream) { chunks.push(chunk); } - expect(chunks.length).toBe(2); - expect(chunks[0]).toEqual({ - type: 'text', - text: 'Hello' - }); - expect(chunks[1]).toEqual({ - type: 'text', - text: ' world!' - }); + expect(chunks.length).toBeGreaterThan(0); + const textChunks = chunks.filter(chunk => chunk.type === 'text'); + expect(textChunks).toHaveLength(1); + expect(textChunks[0].text).toBe('Test response'); + }); + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); + + const stream = handler.createMessage(systemPrompt, messages); + + await expect(async () => { + for await (const chunk of stream) { + // Should not reach here + } + }).rejects.toThrow('API Error'); + }); + }); + + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); expect(mockCreate).toHaveBeenCalledWith({ - model: 'llama2', - messages: expect.arrayContaining([ - { - role: 'system', - content: systemPrompt - } - ]), + model: mockOptions.ollamaModelId, + messages: [{ role: 'user', content: 'Test prompt' }], temperature: 0, - stream: true + stream: false }); }); it('should handle API errors', async () => { - const mockError = new Error('Ollama API error'); - const mockCreate = jest.fn().mockRejectedValue(mockError); - (handler['client'].chat.completions as any).create = mockCreate; + mockCreate.mockRejectedValueOnce(new Error('API Error')); + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Ollama completion error: API Error'); + }); - const stream = handler.createMessage(systemPrompt, mockMessages); - - await expect(async () => { - for await (const chunk of stream) { - // Should throw before yielding any chunks - } - }).rejects.toThrow('Ollama API error'); + it('should handle empty response', async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: '' } }] + }); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); }); }); describe('getModel', () => { - it('should return model info with sane defaults', () => { + it('should return model info', () => { const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe('llama2'); + expect(modelInfo.id).toBe(mockOptions.ollamaModelId); expect(modelInfo.info).toBeDefined(); expect(modelInfo.info.maxTokens).toBe(-1); expect(modelInfo.info.contextWindow).toBe(128_000); }); - - it('should return empty string as model ID if not provided', () => { - const noModelHandler = new OllamaHandler({}); - const modelInfo = 
noModelHandler.getModel(); - expect(modelInfo.id).toBe(''); - expect(modelInfo.info).toBeDefined(); - }); }); }); \ No newline at end of file diff --git a/src/api/providers/__tests__/openai-native.test.ts b/src/api/providers/__tests__/openai-native.test.ts index ece832a..fe40804 100644 --- a/src/api/providers/__tests__/openai-native.test.ts +++ b/src/api/providers/__tests__/openai-native.test.ts @@ -1,230 +1,209 @@ -import { OpenAiNativeHandler } from "../openai-native" -import OpenAI from "openai" -import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api" -import { Anthropic } from "@anthropic-ai/sdk" +import { OpenAiNativeHandler } from '../openai-native'; +import { ApiHandlerOptions } from '../../../shared/api'; +import OpenAI from 'openai'; +import { Anthropic } from '@anthropic-ai/sdk'; -// Mock OpenAI -jest.mock("openai") - -describe("OpenAiNativeHandler", () => { - let handler: OpenAiNativeHandler - let mockOptions: ApiHandlerOptions - let mockOpenAIClient: jest.Mocked - let mockCreate: jest.Mock - - beforeEach(() => { - // Reset mocks - jest.clearAllMocks() - - // Setup mock options - mockOptions = { - openAiNativeApiKey: "test-api-key", - apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts - } - - // Setup mock create function - mockCreate = jest.fn() - - // Setup mock OpenAI client - mockOpenAIClient = { +// Mock OpenAI client +const mockCreate = jest.fn(); +jest.mock('openai', () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ chat: { completions: { - create: mockCreate, - }, - }, - } as unknown as jest.Mocked + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: 'test-completion', + choices: [{ + message: { role: 'assistant', content: 'Test response' }, + finish_reason: 'stop', + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [{ + delta: { content: 'Test response' }, + index: 0 + }], + usage: null + }; + yield { + choices: [{ + delta: {}, + index: 0 + }], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15 + } + }; + } + }; + }) + } + } + })) + }; +}); - // Mock OpenAI constructor - ;(OpenAI as jest.MockedClass).mockImplementation(() => mockOpenAIClient) +describe('OpenAiNativeHandler', () => { + let handler: OpenAiNativeHandler; + let mockOptions: ApiHandlerOptions; - // Create handler instance - handler = new OpenAiNativeHandler(mockOptions) - }) + beforeEach(() => { + mockOptions = { + apiModelId: 'gpt-4o', + openAiNativeApiKey: 'test-api-key' + }; + handler = new OpenAiNativeHandler(mockOptions); + mockCreate.mockClear(); + }); - describe("constructor", () => { - it("should initialize with provided options", () => { - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: mockOptions.openAiNativeApiKey, - }) - }) - }) + describe('constructor', () => { + it('should initialize with provided options', () => { + expect(handler).toBeInstanceOf(OpenAiNativeHandler); + expect(handler.getModel().id).toBe(mockOptions.apiModelId); + }); - describe("getModel", () => { - it("should return specified model when valid", () => { - const result = handler.getModel() - expect(result.id).toBe("gpt-4o") // Use the correct model ID - }) + it('should initialize with empty API key', () => { + const handlerWithoutKey = new OpenAiNativeHandler({ + apiModelId: 'gpt-4o', + openAiNativeApiKey: '' + }); + 
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler); + }); + }); - it("should return default model when model ID is invalid", () => { - handler = new OpenAiNativeHandler({ - ...mockOptions, - apiModelId: "invalid-model" as any, - }) - const result = handler.getModel() - expect(result.id).toBe(openAiNativeDefaultModelId) - }) - - it("should return default model when model ID is not provided", () => { - handler = new OpenAiNativeHandler({ - ...mockOptions, - apiModelId: undefined, - }) - const result = handler.getModel() - expect(result.id).toBe(openAiNativeDefaultModelId) - }) - }) - - describe("createMessage", () => { - const systemPrompt = "You are a helpful assistant" + describe('createMessage', () => { + const systemPrompt = 'You are a helpful assistant.'; const messages: Anthropic.Messages.MessageParam[] = [ - { role: "user", content: "Hello" }, - ] + { + role: 'user', + content: 'Hello!' + } + ]; - describe("o1 models", () => { - beforeEach(() => { - handler = new OpenAiNativeHandler({ - ...mockOptions, - apiModelId: "o1-preview", - }) - }) + it('should handle streaming responses', async () => { + const stream = handler.createMessage(systemPrompt, messages); + const chunks: any[] = []; + for await (const chunk of stream) { + chunks.push(chunk); + } - it("should handle non-streaming response for o1 models", async () => { - const mockResponse = { - choices: [{ message: { content: "Hello there!" } }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - }, - } + expect(chunks.length).toBeGreaterThan(0); + const textChunks = chunks.filter(chunk => chunk.type === 'text'); + expect(textChunks).toHaveLength(1); + expect(textChunks[0].text).toBe('Test response'); + }); - mockCreate.mockResolvedValueOnce(mockResponse) + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); - const generator = handler.createMessage(systemPrompt, messages) - const results = [] - for await (const result of generator) { - results.push(result) - } + const stream = handler.createMessage(systemPrompt, messages); - expect(results).toEqual([ - { type: "text", text: "Hello there!" }, - { type: "usage", inputTokens: 10, outputTokens: 5 }, - ]) - - expect(mockCreate).toHaveBeenCalledWith({ - model: "o1-preview", - messages: [ - { role: "user", content: systemPrompt }, - { role: "user", content: "Hello" }, - ], - }) - }) - - it("should handle missing content in response", async () => { - const mockResponse = { - choices: [{ message: { content: null } }], - usage: null, - } - - mockCreate.mockResolvedValueOnce(mockResponse) - - const generator = handler.createMessage(systemPrompt, messages) - const results = [] - for await (const result of generator) { - results.push(result) - } - - expect(results).toEqual([ - { type: "text", text: "" }, - { type: "usage", inputTokens: 0, outputTokens: 0 }, - ]) - }) - }) - - describe("streaming models", () => { - beforeEach(() => { - handler = new OpenAiNativeHandler({ - ...mockOptions, - apiModelId: "gpt-4o", - }) - }) - - it("should handle streaming response", async () => { - const mockStream = [ - { choices: [{ delta: { content: "Hello" } }], usage: null }, - { choices: [{ delta: { content: " there" } }], usage: null }, - { choices: [{ delta: { content: "!" 
} }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, - ] - - mockCreate.mockResolvedValueOnce( - (async function* () { - for (const chunk of mockStream) { - yield chunk - } - })() - ) - - const generator = handler.createMessage(systemPrompt, messages) - const results = [] - for await (const result of generator) { - results.push(result) - } - - expect(results).toEqual([ - { type: "text", text: "Hello" }, - { type: "text", text: " there" }, - { type: "text", text: "!" }, - { type: "usage", inputTokens: 10, outputTokens: 5 }, - ]) - - expect(mockCreate).toHaveBeenCalledWith({ - model: "gpt-4o", - temperature: 0, - messages: [ - { role: "system", content: systemPrompt }, - { role: "user", content: "Hello" }, - ], - stream: true, - stream_options: { include_usage: true }, - }) - }) - - it("should handle empty delta content", async () => { - const mockStream = [ - { choices: [{ delta: {} }], usage: null }, - { choices: [{ delta: { content: null } }], usage: null }, - { choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, - ] - - mockCreate.mockResolvedValueOnce( - (async function* () { - for (const chunk of mockStream) { - yield chunk - } - })() - ) - - const generator = handler.createMessage(systemPrompt, messages) - const results = [] - for await (const result of generator) { - results.push(result) - } - - expect(results).toEqual([ - { type: "text", text: "Hello" }, - { type: "usage", inputTokens: 10, outputTokens: 5 }, - ]) - }) - }) - - it("should handle API errors", async () => { - mockCreate.mockRejectedValueOnce(new Error("API Error")) - - const generator = handler.createMessage(systemPrompt, messages) await expect(async () => { - for await (const _ of generator) { - // consume generator + for await (const chunk of stream) { + // Should not reach here } - }).rejects.toThrow("API Error") - }) - }) -}) \ No newline at end of file + }).rejects.toThrow('API Error'); + }); + }); + + describe('completePrompt', () => { + it('should complete prompt successfully with gpt-4o model', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: 'gpt-4o', + messages: [{ role: 'user', content: 'Test prompt' }], + temperature: 0 + }); + }); + + it('should complete prompt successfully with o1 model', async () => { + handler = new OpenAiNativeHandler({ + apiModelId: 'o1', + openAiNativeApiKey: 'test-api-key' + }); + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: 'o1', + messages: [{ role: 'user', content: 'Test prompt' }] + }); + }); + + it('should complete prompt successfully with o1-preview model', async () => { + handler = new OpenAiNativeHandler({ + apiModelId: 'o1-preview', + openAiNativeApiKey: 'test-api-key' + }); + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: 'o1-preview', + messages: [{ role: 'user', content: 'Test prompt' }] + }); + }); + + it('should complete prompt successfully with o1-mini model', async () => { + handler = new OpenAiNativeHandler({ + apiModelId: 'o1-mini', + openAiNativeApiKey: 'test-api-key' + }); + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: 'o1-mini', + messages: [{ role: 'user', content: 
'Test prompt' }] + }); + }); + + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('OpenAI Native completion error: API Error'); + }); + + it('should handle empty response', async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: '' } }] + }); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + }); + + describe('getModel', () => { + it('should return model info', () => { + const modelInfo = handler.getModel(); + expect(modelInfo.id).toBe(mockOptions.apiModelId); + expect(modelInfo.info).toBeDefined(); + expect(modelInfo.info.maxTokens).toBe(4096); + expect(modelInfo.info.contextWindow).toBe(128_000); + }); + + it('should handle undefined model ID', () => { + const handlerWithoutModel = new OpenAiNativeHandler({ + openAiNativeApiKey: 'test-api-key' + }); + const modelInfo = handlerWithoutModel.getModel(); + expect(modelInfo.id).toBe('gpt-4o'); // Default model + expect(modelInfo.info).toBeDefined(); + }); + }); +}); \ No newline at end of file diff --git a/src/api/providers/__tests__/openai.test.ts b/src/api/providers/__tests__/openai.test.ts index edd6460..4a4a449 100644 --- a/src/api/providers/__tests__/openai.test.ts +++ b/src/api/providers/__tests__/openai.test.ts @@ -176,6 +176,32 @@ describe('OpenAiHandler', () => { }); }); + describe('completePrompt', () => { + it('should complete prompt successfully', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.openAiModelId, + messages: [{ role: 'user', content: 'Test prompt' }], + temperature: 0 + }); + }); + + it('should handle API errors', async () => { + mockCreate.mockRejectedValueOnce(new Error('API Error')); + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('OpenAI completion error: API Error'); + }); + + it('should handle empty response', async () => { + mockCreate.mockImplementationOnce(() => ({ + choices: [{ message: { content: '' } }] + })); + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + }); + describe('getModel', () => { it('should return model info with sane defaults', () => { const model = handler.getModel(); diff --git a/src/api/providers/__tests__/vertex.test.ts b/src/api/providers/__tests__/vertex.test.ts index 71aa810..be5899f 100644 --- a/src/api/providers/__tests__/vertex.test.ts +++ b/src/api/providers/__tests__/vertex.test.ts @@ -6,7 +6,42 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'; jest.mock('@anthropic-ai/vertex-sdk', () => ({ AnthropicVertex: jest.fn().mockImplementation(() => ({ messages: { - create: jest.fn() + create: jest.fn().mockImplementation(async (options) => { + if (!options.stream) { + return { + id: 'test-completion', + content: [ + { type: 'text', text: 'Test response' } + ], + role: 'assistant', + model: options.model, + usage: { + input_tokens: 10, + output_tokens: 5 + } + } + } + return { + async *[Symbol.asyncIterator]() { + yield { + type: 'message_start', + message: { + usage: { + input_tokens: 10, + output_tokens: 5 + } + } + } + yield { + type: 'content_block_start', + content_block: { + type: 'text', + text: 'Test response' + } + } + } + } + }) } })) })); @@ -196,6 +231,49 @@ describe('VertexHandler', () => { }); }); + describe('completePrompt', () => { + it('should complete 
prompt successfully', async () => { + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe('Test response'); + expect(handler['client'].messages.create).toHaveBeenCalledWith({ + model: 'claude-3-5-sonnet-v2@20241022', + max_tokens: 8192, + temperature: 0, + messages: [{ role: 'user', content: 'Test prompt' }], + stream: false + }); + }); + + it('should handle API errors', async () => { + const mockError = new Error('Vertex API error'); + const mockCreate = jest.fn().mockRejectedValue(mockError); + (handler['client'].messages as any).create = mockCreate; + + await expect(handler.completePrompt('Test prompt')) + .rejects.toThrow('Vertex completion error: Vertex API error'); + }); + + it('should handle non-text content', async () => { + const mockCreate = jest.fn().mockResolvedValue({ + content: [{ type: 'image' }] + }); + (handler['client'].messages as any).create = mockCreate; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + + it('should handle empty response', async () => { + const mockCreate = jest.fn().mockResolvedValue({ + content: [{ type: 'text', text: '' }] + }); + (handler['client'].messages as any).create = mockCreate; + + const result = await handler.completePrompt('Test prompt'); + expect(result).toBe(''); + }); + }); + describe('getModel', () => { it('should return correct model info', () => { const modelInfo = handler.getModel(); diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index c090f17..5184281 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -7,10 +7,10 @@ import { ApiHandlerOptions, ModelInfo, } from "../../shared/api" -import { ApiHandler } from "../index" +import { ApiHandler, SingleCompletionHandler } from "../index" import { ApiStream } from "../transform/stream" -export class AnthropicHandler implements ApiHandler { +export class AnthropicHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: Anthropic @@ -173,4 +173,27 @@ export class AnthropicHandler implements ApiHandler { } return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] } } + + async completePrompt(prompt: string): Promise { + try { + const response = await this.client.messages.create({ + model: this.getModel().id, + max_tokens: this.getModel().info.maxTokens || 8192, + temperature: 0, + messages: [{ role: "user", content: prompt }], + stream: false + }) + + const content = response.content[0] + if (content.type === 'text') { + return content.text + } + return '' + } catch (error) { + if (error instanceof Error) { + throw new Error(`Anthropic completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 3b691c1..3d07895 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -1,6 +1,6 @@ -import { BedrockRuntimeClient, ConverseStreamCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime" +import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime" import { Anthropic } from "@anthropic-ai/sdk" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" import { ApiStream } from "../transform/stream" import { 
convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format" @@ -38,7 +38,7 @@ export interface StreamEvent { }; } -export class AwsBedrockHandler implements ApiHandler { +export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: BedrockRuntimeClient @@ -199,7 +199,7 @@ export class AwsBedrockHandler implements ApiHandler { if (modelId) { // For tests, allow any model ID if (process.env.NODE_ENV === 'test') { - return { + return { id: modelId, info: { maxTokens: 5000, @@ -214,9 +214,68 @@ export class AwsBedrockHandler implements ApiHandler { return { id, info: bedrockModels[id] } } } - return { - id: bedrockDefaultModelId, - info: bedrockModels[bedrockDefaultModelId] + return { + id: bedrockDefaultModelId, + info: bedrockModels[bedrockDefaultModelId] + } + } + + async completePrompt(prompt: string): Promise { + try { + const modelConfig = this.getModel() + + // Handle cross-region inference + let modelId: string + if (this.options.awsUseCrossRegionInference) { + let regionPrefix = (this.options.awsRegion || "").slice(0, 3) + switch (regionPrefix) { + case "us-": + modelId = `us.${modelConfig.id}` + break + case "eu-": + modelId = `eu.${modelConfig.id}` + break + default: + modelId = modelConfig.id + break + } + } else { + modelId = modelConfig.id + } + + const payload = { + modelId, + messages: convertToBedrockConverseMessages([{ + role: "user", + content: prompt + }]), + inferenceConfig: { + maxTokens: modelConfig.info.maxTokens || 5000, + temperature: 0.3, + topP: 0.1 + } + } + + const command = new ConverseCommand(payload) + const response = await this.client.send(command) + + if (response.output && response.output instanceof Uint8Array) { + try { + const outputStr = new TextDecoder().decode(response.output) + const output = JSON.parse(outputStr) + if (output.content) { + return output.content + } + } catch (parseError) { + console.error('Failed to parse Bedrock response:', parseError) + } + } + return '' + } catch (error) { + if (error instanceof Error) { + throw new Error(`Bedrock completion error: ${error.message}`) + } + throw error } } } diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index d7ac5ec..0f6392b 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -1,11 +1,11 @@ import { Anthropic } from "@anthropic-ai/sdk" import { GoogleGenerativeAI } from "@google/generative-ai" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api" import { convertAnthropicMessageToGemini } from "../transform/gemini-format" import { ApiStream } from "../transform/stream" -export class GeminiHandler implements ApiHandler { +export class GeminiHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: GoogleGenerativeAI @@ -53,4 +53,26 @@ export class GeminiHandler implements ApiHandler { } return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] } } + + async completePrompt(prompt: string): Promise { + try { + const model = this.client.getGenerativeModel({ + model: this.getModel().id, + }) + + const result = await model.generateContent({ + contents: [{ role: "user", parts: [{ text: prompt }] }], + generationConfig: { + temperature: 0, + }, + }) + + return result.response.text() + } catch (error) { + if (error instanceof 
Error) { + throw new Error(`Gemini completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/glama.ts b/src/api/providers/glama.ts index c17db05..7e95d0c 100644 --- a/src/api/providers/glama.ts +++ b/src/api/providers/glama.ts @@ -1,13 +1,13 @@ import { Anthropic } from "@anthropic-ai/sdk" import axios from "axios" import OpenAI from "openai" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" import delay from "delay" -export class GlamaHandler implements ApiHandler { +export class GlamaHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: OpenAI @@ -129,4 +129,26 @@ export class GlamaHandler implements ApiHandler { return { id: glamaDefaultModelId, info: glamaDefaultModelInfo } } + + async completePrompt(prompt: string): Promise { + try { + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: this.getModel().id, + messages: [{ role: "user", content: prompt }], + temperature: 0, + } + + if (this.getModel().id.startsWith("anthropic/")) { + requestOptions.max_tokens = 8192 + } + + const response = await this.client.chat.completions.create(requestOptions) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Glama completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/lmstudio.ts b/src/api/providers/lmstudio.ts index 868ef7d..e5c6256 100644 --- a/src/api/providers/lmstudio.ts +++ b/src/api/providers/lmstudio.ts @@ -1,11 +1,11 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" -export class LmStudioHandler implements ApiHandler { +export class LmStudioHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: OpenAI @@ -53,4 +53,20 @@ export class LmStudioHandler implements ApiHandler { info: openAiModelInfoSaneDefaults, } } + + async completePrompt(prompt: string): Promise { + try { + const response = await this.client.chat.completions.create({ + model: this.getModel().id, + messages: [{ role: "user", content: prompt }], + temperature: 0, + stream: false + }) + return response.choices[0]?.message.content || "" + } catch (error) { + throw new Error( + "Please check the LM Studio developer logs to debug what went wrong. 
You may need to load the model with a larger context length to work with Cline's prompts.", + ) + } + } } diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts index 7668bd3..9df73d6 100644 --- a/src/api/providers/ollama.ts +++ b/src/api/providers/ollama.ts @@ -1,11 +1,11 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" -export class OllamaHandler implements ApiHandler { +export class OllamaHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: OpenAI @@ -46,4 +46,21 @@ export class OllamaHandler implements ApiHandler { info: openAiModelInfoSaneDefaults, } } + + async completePrompt(prompt: string): Promise { + try { + const response = await this.client.chat.completions.create({ + model: this.getModel().id, + messages: [{ role: "user", content: prompt }], + temperature: 0, + stream: false + }) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Ollama completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index 139b3a2..83644c9 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -1,6 +1,6 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, ModelInfo, @@ -11,7 +11,7 @@ import { import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" -export class OpenAiNativeHandler implements ApiHandler { +export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: OpenAI @@ -83,4 +83,37 @@ export class OpenAiNativeHandler implements ApiHandler { } return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] } } + + async completePrompt(prompt: string): Promise { + try { + const modelId = this.getModel().id + let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming + + switch (modelId) { + case "o1": + case "o1-preview": + case "o1-mini": + // o1 doesn't support non-1 temp or system prompt + requestOptions = { + model: modelId, + messages: [{ role: "user", content: prompt }] + } + break + default: + requestOptions = { + model: modelId, + messages: [{ role: "user", content: prompt }], + temperature: 0 + } + } + + const response = await this.client.chat.completions.create(requestOptions) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`OpenAI Native completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 3ec2192..0878028 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -6,11 +6,11 @@ import { ModelInfo, openAiModelInfoSaneDefaults, } from "../../shared/api" -import { ApiHandler } from "../index" +import { ApiHandler, SingleCompletionHandler } from "../index" import { convertToOpenAiMessages } 
from "../transform/openai-format" import { ApiStream } from "../transform/stream" -export class OpenAiHandler implements ApiHandler { +export class OpenAiHandler implements ApiHandler, SingleCompletionHandler { protected options: ApiHandlerOptions private client: OpenAI @@ -100,4 +100,22 @@ export class OpenAiHandler implements ApiHandler { info: openAiModelInfoSaneDefaults, } } + + async completePrompt(prompt: string): Promise { + try { + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: this.getModel().id, + messages: [{ role: "user", content: prompt }], + temperature: 0, + } + + const response = await this.client.chat.completions.create(requestOptions) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`OpenAI completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index 60e6967..aed704e 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -1,11 +1,11 @@ import { Anthropic } from "@anthropic-ai/sdk" import { AnthropicVertex } from "@anthropic-ai/vertex-sdk" -import { ApiHandler } from "../" +import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api" import { ApiStream } from "../transform/stream" // https://docs.anthropic.com/en/api/claude-on-vertex-ai -export class VertexHandler implements ApiHandler { +export class VertexHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: AnthropicVertex @@ -83,4 +83,27 @@ export class VertexHandler implements ApiHandler { } return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] } } + + async completePrompt(prompt: string): Promise { + try { + const response = await this.client.messages.create({ + model: this.getModel().id, + max_tokens: this.getModel().info.maxTokens || 8192, + temperature: 0, + messages: [{ role: "user", content: prompt }], + stream: false + }) + + const content = response.content[0] + if (content.type === 'text') { + return content.text + } + return '' + } catch (error) { + if (error instanceof Error) { + throw new Error(`Vertex completion error: ${error.message}`) + } + throw error + } + } }