Add non-streaming completePrompt to all providers

Matt Rubens
2025-01-13 16:16:58 -05:00
parent 2d176e5c92
commit 4027e1c10c
18 changed files with 1235 additions and 438 deletions
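
Every handler changed here keeps its streaming createMessage and additionally implements a new SingleCompletionHandler interface whose completePrompt(prompt) resolves to the model's reply as a plain string (empty string when no text comes back). The interface's own declaration is not part of this diff, so the sketch below is an assumption inferred from the method signature the providers implement.

// Hypothetical sketch (not shown in this diff) of the SingleCompletionHandler
// interface the handlers import from the api index module.
export interface SingleCompletionHandler {
	// Send a single user prompt without streaming and resolve to the assistant's
	// text reply (empty string when the model returns no text content).
	completePrompt(prompt: string): Promise<string>
}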

View File

@@ -46,7 +46,42 @@ jest.mock('@anthropic-ai/sdk', () => {
}
},
messages: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
};
@@ -144,6 +179,42 @@ describe('AnthropicHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
max_tokens: 8192,
temperature: 0,
stream: false
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Anthropic completion error: API Error');
});
it('should handle non-text content', async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'image' }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle empty response', async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'text', text: '' }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return default model if no model ID is provided', () => {
const handlerWithoutModel = new AnthropicHandler({

View File

@@ -119,6 +119,108 @@ describe('AwsBedrockHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({
content: 'Test response'
}))
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
messages: expect.arrayContaining([
expect.objectContaining({
role: 'user',
content: [{ text: 'Test prompt' }]
})
]),
inferenceConfig: expect.objectContaining({
maxTokens: 5000,
temperature: 0.3,
topP: 0.1
})
})
}));
});
it('should handle API errors', async () => {
const mockError = new Error('AWS Bedrock error');
const mockSend = jest.fn().mockRejectedValue(mockError);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
});
it('should handle invalid response format', async () => {
const mockResponse = {
output: new TextEncoder().encode('invalid json')
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle empty response', async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({}))
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle cross-region inference', async () => {
handler = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
awsAccessKey: 'test-access-key',
awsSecretKey: 'test-secret-key',
awsRegion: 'us-east-1',
awsUseCrossRegionInference: true
});
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({
content: 'Test response'
}))
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
})
}));
});
});
describe('getModel', () => {
it('should return correct model info in test environment', () => {
const modelInfo = handler.getModel();

View File

@@ -6,7 +6,12 @@ import { GoogleGenerativeAI } from '@google/generative-ai';
jest.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: jest.fn().mockReturnValue({
generateContentStream: jest.fn(),
generateContent: jest.fn().mockResolvedValue({
response: {
text: () => 'Test response'
}
})
})
}))
}));
@@ -133,6 +138,59 @@ describe('GeminiHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => 'Test response'
}
});
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: 'gemini-2.0-flash-thinking-exp-1219'
});
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
generationConfig: {
temperature: 0
}
});
});
it('should handle API errors', async () => {
const mockError = new Error('Gemini API error');
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Gemini completion error: Gemini API error');
});
it('should handle empty response', async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => ''
}
});
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();

View File

@@ -0,0 +1,226 @@
import { GlamaHandler } from '../glama';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import axios from 'axios';
// Mock OpenAI client
const mockCreate = jest.fn();
const mockWithResponse = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: (...args: any[]) => {
const stream = {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
const result = mockCreate(...args);
if (args[0].stream) {
mockWithResponse.mockReturnValue(Promise.resolve({
data: stream,
response: {
headers: {
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
}
}
}));
result.withResponse = mockWithResponse;
}
return result;
}
}
}
}))
};
});
describe('GlamaHandler', () => {
let handler: GlamaHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
mockOptions = {
apiModelId: 'anthropic/claude-3-5-sonnet',
glamaModelId: 'anthropic/claude-3-5-sonnet',
glamaApiKey: 'test-api-key'
};
handler = new GlamaHandler(mockOptions);
mockCreate.mockClear();
mockWithResponse.mockClear();
// Default mock implementation for non-streaming responses
mockCreate.mockResolvedValue({
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
});
});
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(GlamaHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
});
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it('should handle streaming responses', async () => {
// Mock axios for token usage request
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
data: {
tokenUsage: {
promptTokens: 10,
completionTokens: 5,
cacheCreationInputTokens: 0,
cacheReadInputTokens: 0
},
totalCostUsd: "0.00"
}
});
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBe(2); // Text chunk and usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: 'Test response'
});
expect(chunks[1]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 5,
cacheWriteTokens: 0,
cacheReadTokens: 0,
totalCost: 0
});
mockAxios.mockRestore();
});
it('should handle API errors', async () => {
mockCreate.mockImplementationOnce(() => {
throw new Error('API Error');
});
const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
try {
for await (const chunk of stream) {
chunks.push(chunk);
}
fail('Expected error to be thrown');
} catch (error) {
expect(error).toBeInstanceOf(Error);
expect(error.message).toBe('API Error');
}
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
max_tokens: 8192
}));
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Glama completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should not set max_tokens for non-Anthropic models', async () => {
// Reset mock to clear any previous calls
mockCreate.mockClear();
const nonAnthropicOptions = {
apiModelId: 'openai/gpt-4',
glamaModelId: 'openai/gpt-4',
glamaApiKey: 'test-key',
glamaModelInfo: {
maxTokens: 4096,
contextWindow: 8192,
supportsImages: true,
supportsPromptCache: false
}
};
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
await nonAnthropicHandler.completePrompt('Test prompt');
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: 'openai/gpt-4',
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
}));
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
});
});
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.apiModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(8192);
expect(modelInfo.info.contextWindow).toBe(200_000);
});
});
});

View File

@@ -1,148 +1,160 @@
import { LmStudioHandler } from '../lmstudio';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
describe('LmStudioHandler', () => {
let handler: LmStudioHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
mockOptions = {
apiModelId: 'local-model',
lmStudioModelId: 'local-model',
lmStudioBaseUrl: 'http://localhost:1234/v1'
};
handler = new LmStudioHandler(mockOptions);
mockCreate.mockClear();
});
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(LmStudioHandler);
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
});
it('should use default base URL if not provided', () => {
const handlerWithoutUrl = new LmStudioHandler({
apiModelId: 'local-model',
lmStudioModelId: 'local-model'
});
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
});
});
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const stream = handler.createMessage(systemPrompt, messages);
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.lmStudioModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
stream: false
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(-1);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it('should return empty string as model ID if not provided', () => {
const noModelHandler = new LmStudioHandler({});
const modelInfo = noModelHandler.getModel();
expect(modelInfo.id).toBe('');
expect(modelInfo.info).toBeDefined();
});
});
});

View File

@@ -1,148 +1,160 @@
import { OllamaHandler } from '../ollama';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
describe('OllamaHandler', () => {
let handler: OllamaHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
mockOptions = {
apiModelId: 'llama2',
ollamaModelId: 'llama2',
ollamaBaseUrl: 'http://localhost:11434/v1'
};
handler = new OllamaHandler(mockOptions);
mockCreate.mockClear();
});
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OllamaHandler);
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
});
it('should use default base URL if not provided', () => {
const handlerWithoutUrl = new OllamaHandler({
apiModelId: 'llama2',
ollamaModelId: 'llama2'
});
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
});
});
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const stream = handler.createMessage(systemPrompt, messages);
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('API Error');
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.ollamaModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
stream: false
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Ollama completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(-1);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it('should return empty string as model ID if not provided', () => {
const noModelHandler = new OllamaHandler({});
const modelInfo = noModelHandler.getModel();
expect(modelInfo.id).toBe('');
expect(modelInfo.info).toBeDefined();
});
});
});

View File

@@ -1,230 +1,209 @@
import { OpenAiNativeHandler } from '../openai-native';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
describe('OpenAiNativeHandler', () => {
let handler: OpenAiNativeHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
mockOptions = {
apiModelId: 'gpt-4o',
openAiNativeApiKey: 'test-api-key'
};
handler = new OpenAiNativeHandler(mockOptions);
mockCreate.mockClear();
});
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
it('should initialize with empty API key', () => {
const handlerWithoutKey = new OpenAiNativeHandler({
apiModelId: 'gpt-4o',
openAiNativeApiKey: ''
});
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
});
});
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const stream = handler.createMessage(systemPrompt, messages);
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('API Error');
});
});
describe('completePrompt', () => {
it('should complete prompt successfully with gpt-4o model', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'gpt-4o',
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
});
});
it('should complete prompt successfully with o1 model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it('should complete prompt successfully with o1-preview model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1-preview',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-preview',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it('should complete prompt successfully with o1-mini model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1-mini',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-mini',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('OpenAI Native completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.apiModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(4096);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it('should handle undefined model ID', () => {
const handlerWithoutModel = new OpenAiNativeHandler({
openAiNativeApiKey: 'test-api-key'
});
const modelInfo = handlerWithoutModel.getModel();
expect(modelInfo.id).toBe('gpt-4o'); // Default model
expect(modelInfo.info).toBeDefined();
});
});
});

View File

@@ -176,6 +176,32 @@ describe('OpenAiHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('OpenAI completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockImplementationOnce(() => ({
choices: [{ message: { content: '' } }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info with sane defaults', () => {
const model = handler.getModel();

View File

@@ -6,7 +6,42 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
jest.mock('@anthropic-ai/vertex-sdk', () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: {
create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
}));
@@ -196,6 +231,49 @@ describe('VertexHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(handler['client'].messages.create).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022',
max_tokens: 8192,
temperature: 0,
messages: [{ role: 'user', content: 'Test prompt' }],
stream: false
});
});
it('should handle API errors', async () => {
const mockError = new Error('Vertex API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].messages as any).create = mockCreate;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Vertex completion error: Vertex API error');
});
it('should handle non-text content', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'image' }]
});
(handler['client'].messages as any).create = mockCreate;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle empty response', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'text', text: '' }]
});
(handler['client'].messages as any).create = mockCreate;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();

View File

@@ -7,10 +7,10 @@ import {
ApiHandlerOptions,
ModelInfo,
} from "../../shared/api"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { ApiStream } from "../transform/stream"
export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: Anthropic
@@ -173,4 +173,27 @@ export class AnthropicHandler implements ApiHandler {
}
return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.messages.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0,
messages: [{ role: "user", content: prompt }],
stream: false
})
const content = response.content[0]
if (content.type === 'text') {
return content.text
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Anthropic completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,6 +1,6 @@
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
@@ -38,7 +38,7 @@ export interface StreamEvent {
};
}
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: BedrockRuntimeClient
@@ -219,4 +219,63 @@ export class AwsBedrockHandler implements ApiHandler {
info: bedrockModels[bedrockDefaultModelId]
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const modelConfig = this.getModel()
// Handle cross-region inference
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${modelConfig.id}`
break
case "eu-":
modelId = `eu.${modelConfig.id}`
break
default:
modelId = modelConfig.id
break
}
} else {
modelId = modelConfig.id
}
const payload = {
modelId,
messages: convertToBedrockConverseMessages([{
role: "user",
content: prompt
}]),
inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3,
topP: 0.1
}
}
const command = new ConverseCommand(payload)
const response = await this.client.send(command)
if (response.output && response.output instanceof Uint8Array) {
try {
const outputStr = new TextDecoder().decode(response.output)
const output = JSON.parse(outputStr)
if (output.content) {
return output.content
}
} catch (parseError) {
console.error('Failed to parse Bedrock response:', parseError)
}
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Bedrock completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { GoogleGenerativeAI } from "@google/generative-ai"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
import { ApiStream } from "../transform/stream"
export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: GoogleGenerativeAI
@@ -53,4 +53,26 @@ export class GeminiHandler implements ApiHandler {
}
return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const model = this.client.getGenerativeModel({
model: this.getModel().id,
})
const result = await model.generateContent({
contents: [{ role: "user", parts: [{ text: prompt }] }],
generationConfig: {
temperature: 0,
},
})
return result.response.text()
} catch (error) {
if (error instanceof Error) {
throw new Error(`Gemini completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,13 +1,13 @@
import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"
import OpenAI from "openai"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import delay from "delay"
export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -129,4 +129,26 @@ export class GlamaHandler implements ApiHandler {
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
}
async completePrompt(prompt: string): Promise<string> {
try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
}
if (this.getModel().id.startsWith("anthropic/")) {
requestOptions.max_tokens = 8192
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Glama completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -53,4 +53,20 @@ export class LmStudioHandler implements ApiHandler {
info: openAiModelInfoSaneDefaults,
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.chat.completions.create({
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
})
return response.choices[0]?.message.content || ""
} catch (error) {
throw new Error(
"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts.",
)
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -46,4 +46,21 @@ export class OllamaHandler implements ApiHandler {
info: openAiModelInfoSaneDefaults,
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.chat.completions.create({
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
})
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Ollama completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,6 +1,6 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, SingleCompletionHandler } from "../"
import {
ApiHandlerOptions,
ModelInfo,
@@ -11,7 +11,7 @@ import {
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -83,4 +83,37 @@ export class OpenAiNativeHandler implements ApiHandler {
}
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const modelId = this.getModel().id
let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
switch (modelId) {
case "o1":
case "o1-preview":
case "o1-mini":
// o1 doesn't support non-1 temp or system prompt
requestOptions = {
model: modelId,
messages: [{ role: "user", content: prompt }]
}
break
default:
requestOptions = {
model: modelId,
messages: [{ role: "user", content: prompt }],
temperature: 0
}
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`OpenAI Native completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -6,11 +6,11 @@ import {
ModelInfo,
openAiModelInfoSaneDefaults,
} from "../../shared/api"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
@@ -100,4 +100,22 @@ export class OpenAiHandler implements ApiHandler {
info: openAiModelInfoSaneDefaults,
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`OpenAI completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
// https://docs.anthropic.com/en/api/claude-on-vertex-ai
export class VertexHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: AnthropicVertex
@@ -83,4 +83,27 @@ export class VertexHandler implements ApiHandler {
}
return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.messages.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0,
messages: [{ role: "user", content: prompt }],
stream: false
})
const content = response.content[0]
if (content.type === 'text') {
return content.text
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Vertex completion error: ${error.message}`)
}
throw error
}
}
}
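
Since every provider now shares this entry point, a one-shot completion looks the same regardless of backend. A minimal usage sketch follows; only the constructor options and the completePrompt call are established by this commit, and the import path is an assumption based on the relative paths used in the tests.

// Usage sketch, assuming the provider modules live under an api/providers directory.
import { OpenAiNativeHandler } from "./api/providers/openai-native"

async function demo() {
	// Construct a handler the same way the tests do, then request a single
	// non-streaming completion; it resolves to "" if the model returns no text.
	const handler = new OpenAiNativeHandler({
		apiModelId: "gpt-4o",
		openAiNativeApiKey: process.env.OPENAI_API_KEY ?? "",
	})
	const text = await handler.completePrompt("Say hello in one short sentence.")
	console.log(text)
}

demo().catch(console.error)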