mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 12:21:13 -05:00
Add non-streaming completePrompt to all providers
This commit is contained in:
@@ -46,7 +46,42 @@ jest.mock('@anthropic-ai/sdk', () => {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
messages: {
|
messages: {
|
||||||
create: mockCreate
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
|
if (!options.stream) {
|
||||||
|
return {
|
||||||
|
id: 'test-completion',
|
||||||
|
content: [
|
||||||
|
{ type: 'text', text: 'Test response' }
|
||||||
|
],
|
||||||
|
role: 'assistant',
|
||||||
|
model: options.model,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
type: 'message_start',
|
||||||
|
message: {
|
||||||
|
usage: {
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
type: 'content_block_start',
|
||||||
|
content_block: {
|
||||||
|
type: 'text',
|
||||||
|
text: 'Test response'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
};
|
};
|
||||||
@@ -144,6 +179,42 @@ describe('AnthropicHandler', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.apiModelId,
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
max_tokens: 8192,
|
||||||
|
temperature: 0,
|
||||||
|
stream: false
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('Anthropic completion error: API Error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle non-text content', async () => {
|
||||||
|
mockCreate.mockImplementationOnce(async () => ({
|
||||||
|
content: [{ type: 'image' }]
|
||||||
|
}));
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
mockCreate.mockImplementationOnce(async () => ({
|
||||||
|
content: [{ type: 'text', text: '' }]
|
||||||
|
}));
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return default model if no model ID is provided', () => {
|
it('should return default model if no model ID is provided', () => {
|
||||||
const handlerWithoutModel = new AnthropicHandler({
|
const handlerWithoutModel = new AnthropicHandler({
|
||||||
|
|||||||
@@ -119,6 +119,108 @@ describe('AwsBedrockHandler', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
output: new TextEncoder().encode(JSON.stringify({
|
||||||
|
content: 'Test response'
|
||||||
|
}))
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||||
|
handler['client'] = {
|
||||||
|
send: mockSend
|
||||||
|
} as unknown as BedrockRuntimeClient;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
input: expect.objectContaining({
|
||||||
|
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||||
|
messages: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
role: 'user',
|
||||||
|
content: [{ text: 'Test prompt' }]
|
||||||
|
})
|
||||||
|
]),
|
||||||
|
inferenceConfig: expect.objectContaining({
|
||||||
|
maxTokens: 5000,
|
||||||
|
temperature: 0.3,
|
||||||
|
topP: 0.1
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
const mockError = new Error('AWS Bedrock error');
|
||||||
|
const mockSend = jest.fn().mockRejectedValue(mockError);
|
||||||
|
handler['client'] = {
|
||||||
|
send: mockSend
|
||||||
|
} as unknown as BedrockRuntimeClient;
|
||||||
|
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle invalid response format', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
output: new TextEncoder().encode('invalid json')
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||||
|
handler['client'] = {
|
||||||
|
send: mockSend
|
||||||
|
} as unknown as BedrockRuntimeClient;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
output: new TextEncoder().encode(JSON.stringify({}))
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||||
|
handler['client'] = {
|
||||||
|
send: mockSend
|
||||||
|
} as unknown as BedrockRuntimeClient;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle cross-region inference', async () => {
|
||||||
|
handler = new AwsBedrockHandler({
|
||||||
|
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||||
|
awsAccessKey: 'test-access-key',
|
||||||
|
awsSecretKey: 'test-secret-key',
|
||||||
|
awsRegion: 'us-east-1',
|
||||||
|
awsUseCrossRegionInference: true
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockResponse = {
|
||||||
|
output: new TextEncoder().encode(JSON.stringify({
|
||||||
|
content: 'Test response'
|
||||||
|
}))
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||||
|
handler['client'] = {
|
||||||
|
send: mockSend
|
||||||
|
} as unknown as BedrockRuntimeClient;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
input: expect.objectContaining({
|
||||||
|
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
|
||||||
|
})
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return correct model info in test environment', () => {
|
it('should return correct model info in test environment', () => {
|
||||||
const modelInfo = handler.getModel();
|
const modelInfo = handler.getModel();
|
||||||
|
|||||||
@@ -6,7 +6,12 @@ import { GoogleGenerativeAI } from '@google/generative-ai';
|
|||||||
jest.mock('@google/generative-ai', () => ({
|
jest.mock('@google/generative-ai', () => ({
|
||||||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
||||||
getGenerativeModel: jest.fn().mockReturnValue({
|
getGenerativeModel: jest.fn().mockReturnValue({
|
||||||
generateContentStream: jest.fn()
|
generateContentStream: jest.fn(),
|
||||||
|
generateContent: jest.fn().mockResolvedValue({
|
||||||
|
response: {
|
||||||
|
text: () => 'Test response'
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
}));
|
}));
|
||||||
@@ -133,6 +138,59 @@ describe('GeminiHandler', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||||
|
response: {
|
||||||
|
text: () => 'Test response'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent
|
||||||
|
});
|
||||||
|
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||||
|
model: 'gemini-2.0-flash-thinking-exp-1219'
|
||||||
|
});
|
||||||
|
expect(mockGenerateContent).toHaveBeenCalledWith({
|
||||||
|
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
|
||||||
|
generationConfig: {
|
||||||
|
temperature: 0
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
const mockError = new Error('Gemini API error');
|
||||||
|
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
|
||||||
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent
|
||||||
|
});
|
||||||
|
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||||
|
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('Gemini completion error: Gemini API error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||||
|
response: {
|
||||||
|
text: () => ''
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent
|
||||||
|
});
|
||||||
|
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return correct model info', () => {
|
it('should return correct model info', () => {
|
||||||
const modelInfo = handler.getModel();
|
const modelInfo = handler.getModel();
|
||||||
|
|||||||
226
src/api/providers/__tests__/glama.test.ts
Normal file
226
src/api/providers/__tests__/glama.test.ts
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
import { GlamaHandler } from '../glama';
|
||||||
|
import { ApiHandlerOptions } from '../../../shared/api';
|
||||||
|
import OpenAI from 'openai';
|
||||||
|
import { Anthropic } from '@anthropic-ai/sdk';
|
||||||
|
import axios from 'axios';
|
||||||
|
|
||||||
|
// Mock OpenAI client
|
||||||
|
const mockCreate = jest.fn();
|
||||||
|
const mockWithResponse = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('openai', () => {
|
||||||
|
return {
|
||||||
|
__esModule: true,
|
||||||
|
default: jest.fn().mockImplementation(() => ({
|
||||||
|
chat: {
|
||||||
|
completions: {
|
||||||
|
create: (...args: any[]) => {
|
||||||
|
const stream = {
|
||||||
|
[Symbol.asyncIterator]: async function* () {
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: { content: 'Test response' },
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: null
|
||||||
|
};
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: {},
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = mockCreate(...args);
|
||||||
|
if (args[0].stream) {
|
||||||
|
mockWithResponse.mockReturnValue(Promise.resolve({
|
||||||
|
data: stream,
|
||||||
|
response: {
|
||||||
|
headers: {
|
||||||
|
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
result.withResponse = mockWithResponse;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('GlamaHandler', () => {
|
||||||
|
let handler: GlamaHandler;
|
||||||
|
let mockOptions: ApiHandlerOptions;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockOptions = {
|
||||||
|
apiModelId: 'anthropic/claude-3-5-sonnet',
|
||||||
|
glamaModelId: 'anthropic/claude-3-5-sonnet',
|
||||||
|
glamaApiKey: 'test-api-key'
|
||||||
|
};
|
||||||
|
handler = new GlamaHandler(mockOptions);
|
||||||
|
mockCreate.mockClear();
|
||||||
|
mockWithResponse.mockClear();
|
||||||
|
|
||||||
|
// Default mock implementation for non-streaming responses
|
||||||
|
mockCreate.mockResolvedValue({
|
||||||
|
id: 'test-completion',
|
||||||
|
choices: [{
|
||||||
|
message: { role: 'assistant', content: 'Test response' },
|
||||||
|
finish_reason: 'stop',
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('constructor', () => {
|
||||||
|
it('should initialize with provided options', () => {
|
||||||
|
expect(handler).toBeInstanceOf(GlamaHandler);
|
||||||
|
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createMessage', () => {
|
||||||
|
const systemPrompt = 'You are a helpful assistant.';
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'Hello!'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
it('should handle streaming responses', async () => {
|
||||||
|
// Mock axios for token usage request
|
||||||
|
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
|
||||||
|
data: {
|
||||||
|
tokenUsage: {
|
||||||
|
promptTokens: 10,
|
||||||
|
completionTokens: 5,
|
||||||
|
cacheCreationInputTokens: 0,
|
||||||
|
cacheReadInputTokens: 0
|
||||||
|
},
|
||||||
|
totalCostUsd: "0.00"
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
|
const chunks: any[] = [];
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
chunks.push(chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(chunks.length).toBe(2); // Text chunk and usage chunk
|
||||||
|
expect(chunks[0]).toEqual({
|
||||||
|
type: 'text',
|
||||||
|
text: 'Test response'
|
||||||
|
});
|
||||||
|
expect(chunks[1]).toEqual({
|
||||||
|
type: 'usage',
|
||||||
|
inputTokens: 10,
|
||||||
|
outputTokens: 5,
|
||||||
|
cacheWriteTokens: 0,
|
||||||
|
cacheReadTokens: 0,
|
||||||
|
totalCost: 0
|
||||||
|
});
|
||||||
|
|
||||||
|
mockAxios.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
mockCreate.mockImplementationOnce(() => {
|
||||||
|
throw new Error('API Error');
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
|
const chunks = [];
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
chunks.push(chunk);
|
||||||
|
}
|
||||||
|
fail('Expected error to be thrown');
|
||||||
|
} catch (error) {
|
||||||
|
expect(error).toBeInstanceOf(Error);
|
||||||
|
expect(error.message).toBe('API Error');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
model: mockOptions.apiModelId,
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
temperature: 0,
|
||||||
|
max_tokens: 8192
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('Glama completion error: API Error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
mockCreate.mockResolvedValueOnce({
|
||||||
|
choices: [{ message: { content: '' } }]
|
||||||
|
});
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not set max_tokens for non-Anthropic models', async () => {
|
||||||
|
// Reset mock to clear any previous calls
|
||||||
|
mockCreate.mockClear();
|
||||||
|
|
||||||
|
const nonAnthropicOptions = {
|
||||||
|
apiModelId: 'openai/gpt-4',
|
||||||
|
glamaModelId: 'openai/gpt-4',
|
||||||
|
glamaApiKey: 'test-key',
|
||||||
|
glamaModelInfo: {
|
||||||
|
maxTokens: 4096,
|
||||||
|
contextWindow: 8192,
|
||||||
|
supportsImages: true,
|
||||||
|
supportsPromptCache: false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
|
||||||
|
|
||||||
|
await nonAnthropicHandler.completePrompt('Test prompt');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
model: 'openai/gpt-4',
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
temperature: 0
|
||||||
|
}));
|
||||||
|
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getModel', () => {
|
||||||
|
it('should return model info', () => {
|
||||||
|
const modelInfo = handler.getModel();
|
||||||
|
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||||
|
expect(modelInfo.info).toBeDefined();
|
||||||
|
expect(modelInfo.info.maxTokens).toBe(8192);
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(200_000);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,148 +1,160 @@
|
|||||||
import { LmStudioHandler } from '../lmstudio';
|
import { LmStudioHandler } from '../lmstudio';
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { ApiHandlerOptions } from '../../../shared/api';
|
||||||
import OpenAI from 'openai';
|
import OpenAI from 'openai';
|
||||||
|
import { Anthropic } from '@anthropic-ai/sdk';
|
||||||
|
|
||||||
// Mock OpenAI SDK
|
// Mock OpenAI client
|
||||||
jest.mock('openai', () => ({
|
const mockCreate = jest.fn();
|
||||||
|
jest.mock('openai', () => {
|
||||||
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: jest.fn()
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
|
if (!options.stream) {
|
||||||
|
return {
|
||||||
|
id: 'test-completion',
|
||||||
|
choices: [{
|
||||||
|
message: { role: 'assistant', content: 'Test response' },
|
||||||
|
finish_reason: 'stop',
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
[Symbol.asyncIterator]: async function* () {
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: { content: 'Test response' },
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: null
|
||||||
|
};
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: {},
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}));
|
};
|
||||||
|
});
|
||||||
|
|
||||||
describe('LmStudioHandler', () => {
|
describe('LmStudioHandler', () => {
|
||||||
let handler: LmStudioHandler;
|
let handler: LmStudioHandler;
|
||||||
|
let mockOptions: ApiHandlerOptions;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
handler = new LmStudioHandler({
|
mockOptions = {
|
||||||
lmStudioModelId: 'mistral-7b',
|
apiModelId: 'local-model',
|
||||||
lmStudioBaseUrl: 'http://localhost:1234'
|
lmStudioModelId: 'local-model',
|
||||||
});
|
lmStudioBaseUrl: 'http://localhost:1234/v1'
|
||||||
|
};
|
||||||
|
handler = new LmStudioHandler(mockOptions);
|
||||||
|
mockCreate.mockClear();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe('constructor', () => {
|
||||||
it('should initialize with provided config', () => {
|
it('should initialize with provided options', () => {
|
||||||
expect(OpenAI).toHaveBeenCalledWith({
|
expect(handler).toBeInstanceOf(LmStudioHandler);
|
||||||
baseURL: 'http://localhost:1234/v1',
|
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
|
||||||
apiKey: 'noop'
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should use default base URL if not provided', () => {
|
it('should use default base URL if not provided', () => {
|
||||||
const defaultHandler = new LmStudioHandler({
|
const handlerWithoutUrl = new LmStudioHandler({
|
||||||
lmStudioModelId: 'mistral-7b'
|
apiModelId: 'local-model',
|
||||||
});
|
lmStudioModelId: 'local-model'
|
||||||
|
|
||||||
expect(OpenAI).toHaveBeenCalledWith({
|
|
||||||
baseURL: 'http://localhost:1234/v1',
|
|
||||||
apiKey: 'noop'
|
|
||||||
});
|
});
|
||||||
|
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe('createMessage', () => {
|
||||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
const systemPrompt = 'You are a helpful assistant.';
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'Hello'
|
content: 'Hello!'
|
||||||
},
|
|
||||||
{
|
|
||||||
role: 'assistant',
|
|
||||||
content: 'Hi there!'
|
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
|
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
it('should handle streaming responses', async () => {
|
||||||
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
it('should handle streaming responses correctly', async () => {
|
const chunks: any[] = [];
|
||||||
const mockStream = [
|
|
||||||
{
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Hello' }
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
choices: [{
|
|
||||||
delta: { content: ' world!' }
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
];
|
|
||||||
|
|
||||||
// Setup async iterator for mock stream
|
|
||||||
const asyncIterator = {
|
|
||||||
async *[Symbol.asyncIterator]() {
|
|
||||||
for (const chunk of mockStream) {
|
|
||||||
yield chunk;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
|
||||||
(handler['client'].chat.completions as any).create = mockCreate;
|
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
|
||||||
const chunks = [];
|
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks.length).toBe(2);
|
expect(chunks.length).toBeGreaterThan(0);
|
||||||
expect(chunks[0]).toEqual({
|
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||||
type: 'text',
|
expect(textChunks).toHaveLength(1);
|
||||||
text: 'Hello'
|
expect(textChunks[0].text).toBe('Test response');
|
||||||
});
|
|
||||||
expect(chunks[1]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: ' world!'
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
it('should handle API errors', async () => {
|
||||||
model: 'mistral-7b',
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
messages: expect.arrayContaining([
|
|
||||||
{
|
|
||||||
role: 'system',
|
|
||||||
content: systemPrompt
|
|
||||||
}
|
|
||||||
]),
|
|
||||||
temperature: 0,
|
|
||||||
stream: true
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors with custom message', async () => {
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
const mockError = new Error('LM Studio API error');
|
|
||||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
|
||||||
(handler['client'].chat.completions as any).create = mockCreate;
|
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
|
||||||
|
|
||||||
await expect(async () => {
|
await expect(async () => {
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
// Should throw before yielding any chunks
|
// Should not reach here
|
||||||
}
|
}
|
||||||
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.lmStudioModelId,
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
mockCreate.mockResolvedValueOnce({
|
||||||
|
choices: [{ message: { content: '' } }]
|
||||||
|
});
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return model info with sane defaults', () => {
|
it('should return model info', () => {
|
||||||
const modelInfo = handler.getModel();
|
const modelInfo = handler.getModel();
|
||||||
expect(modelInfo.id).toBe('mistral-7b');
|
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
|
||||||
expect(modelInfo.info).toBeDefined();
|
expect(modelInfo.info).toBeDefined();
|
||||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
expect(modelInfo.info.maxTokens).toBe(-1);
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return empty string as model ID if not provided', () => {
|
|
||||||
const noModelHandler = new LmStudioHandler({});
|
|
||||||
const modelInfo = noModelHandler.getModel();
|
|
||||||
expect(modelInfo.id).toBe('');
|
|
||||||
expect(modelInfo.info).toBeDefined();
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1,148 +1,160 @@
|
|||||||
import { OllamaHandler } from '../ollama';
|
import { OllamaHandler } from '../ollama';
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { ApiHandlerOptions } from '../../../shared/api';
|
||||||
import OpenAI from 'openai';
|
import OpenAI from 'openai';
|
||||||
|
import { Anthropic } from '@anthropic-ai/sdk';
|
||||||
|
|
||||||
// Mock OpenAI SDK
|
// Mock OpenAI client
|
||||||
jest.mock('openai', () => ({
|
const mockCreate = jest.fn();
|
||||||
|
jest.mock('openai', () => {
|
||||||
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: jest.fn()
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
|
if (!options.stream) {
|
||||||
|
return {
|
||||||
|
id: 'test-completion',
|
||||||
|
choices: [{
|
||||||
|
message: { role: 'assistant', content: 'Test response' },
|
||||||
|
finish_reason: 'stop',
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
[Symbol.asyncIterator]: async function* () {
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: { content: 'Test response' },
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: null
|
||||||
|
};
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: {},
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}));
|
};
|
||||||
|
});
|
||||||
|
|
||||||
describe('OllamaHandler', () => {
|
describe('OllamaHandler', () => {
|
||||||
let handler: OllamaHandler;
|
let handler: OllamaHandler;
|
||||||
|
let mockOptions: ApiHandlerOptions;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
handler = new OllamaHandler({
|
mockOptions = {
|
||||||
|
apiModelId: 'llama2',
|
||||||
ollamaModelId: 'llama2',
|
ollamaModelId: 'llama2',
|
||||||
ollamaBaseUrl: 'http://localhost:11434'
|
ollamaBaseUrl: 'http://localhost:11434/v1'
|
||||||
});
|
};
|
||||||
|
handler = new OllamaHandler(mockOptions);
|
||||||
|
mockCreate.mockClear();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe('constructor', () => {
|
||||||
it('should initialize with provided config', () => {
|
it('should initialize with provided options', () => {
|
||||||
expect(OpenAI).toHaveBeenCalledWith({
|
expect(handler).toBeInstanceOf(OllamaHandler);
|
||||||
baseURL: 'http://localhost:11434/v1',
|
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
|
||||||
apiKey: 'ollama'
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should use default base URL if not provided', () => {
|
it('should use default base URL if not provided', () => {
|
||||||
const defaultHandler = new OllamaHandler({
|
const handlerWithoutUrl = new OllamaHandler({
|
||||||
|
apiModelId: 'llama2',
|
||||||
ollamaModelId: 'llama2'
|
ollamaModelId: 'llama2'
|
||||||
});
|
});
|
||||||
|
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
|
||||||
expect(OpenAI).toHaveBeenCalledWith({
|
|
||||||
baseURL: 'http://localhost:11434/v1',
|
|
||||||
apiKey: 'ollama'
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe('createMessage', () => {
|
||||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
const systemPrompt = 'You are a helpful assistant.';
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'Hello'
|
content: 'Hello!'
|
||||||
},
|
|
||||||
{
|
|
||||||
role: 'assistant',
|
|
||||||
content: 'Hi there!'
|
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
|
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
it('should handle streaming responses', async () => {
|
||||||
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
it('should handle streaming responses correctly', async () => {
|
const chunks: any[] = [];
|
||||||
const mockStream = [
|
|
||||||
{
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Hello' }
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
choices: [{
|
|
||||||
delta: { content: ' world!' }
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
];
|
|
||||||
|
|
||||||
// Setup async iterator for mock stream
|
|
||||||
const asyncIterator = {
|
|
||||||
async *[Symbol.asyncIterator]() {
|
|
||||||
for (const chunk of mockStream) {
|
|
||||||
yield chunk;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
|
||||||
(handler['client'].chat.completions as any).create = mockCreate;
|
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
|
||||||
const chunks = [];
|
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks.length).toBe(2);
|
expect(chunks.length).toBeGreaterThan(0);
|
||||||
expect(chunks[0]).toEqual({
|
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||||
type: 'text',
|
expect(textChunks).toHaveLength(1);
|
||||||
text: 'Hello'
|
expect(textChunks[0].text).toBe('Test response');
|
||||||
});
|
|
||||||
expect(chunks[1]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: ' world!'
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
it('should handle API errors', async () => {
|
||||||
model: 'llama2',
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
messages: expect.arrayContaining([
|
|
||||||
{
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
role: 'system',
|
|
||||||
content: systemPrompt
|
await expect(async () => {
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
// Should not reach here
|
||||||
}
|
}
|
||||||
]),
|
}).rejects.toThrow('API Error');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.ollamaModelId,
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: true
|
stream: false
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
it('should handle API errors', async () => {
|
||||||
const mockError = new Error('Ollama API error');
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
(handler['client'].chat.completions as any).create = mockCreate;
|
.rejects.toThrow('Ollama completion error: API Error');
|
||||||
|
});
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
it('should handle empty response', async () => {
|
||||||
|
mockCreate.mockResolvedValueOnce({
|
||||||
await expect(async () => {
|
choices: [{ message: { content: '' } }]
|
||||||
for await (const chunk of stream) {
|
});
|
||||||
// Should throw before yielding any chunks
|
const result = await handler.completePrompt('Test prompt');
|
||||||
}
|
expect(result).toBe('');
|
||||||
}).rejects.toThrow('Ollama API error');
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return model info with sane defaults', () => {
|
it('should return model info', () => {
|
||||||
const modelInfo = handler.getModel();
|
const modelInfo = handler.getModel();
|
||||||
expect(modelInfo.id).toBe('llama2');
|
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
|
||||||
expect(modelInfo.info).toBeDefined();
|
expect(modelInfo.info).toBeDefined();
|
||||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
expect(modelInfo.info.maxTokens).toBe(-1);
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return empty string as model ID if not provided', () => {
|
|
||||||
const noModelHandler = new OllamaHandler({});
|
|
||||||
const modelInfo = noModelHandler.getModel();
|
|
||||||
expect(modelInfo.id).toBe('');
|
|
||||||
expect(modelInfo.info).toBeDefined();
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1,230 +1,209 @@
|
|||||||
import { OpenAiNativeHandler } from "../openai-native"
|
import { OpenAiNativeHandler } from '../openai-native';
|
||||||
import OpenAI from "openai"
|
import { ApiHandlerOptions } from '../../../shared/api';
|
||||||
import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api"
|
import OpenAI from 'openai';
|
||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from '@anthropic-ai/sdk';
|
||||||
|
|
||||||
// Mock OpenAI
|
// Mock OpenAI client
|
||||||
jest.mock("openai")
|
const mockCreate = jest.fn();
|
||||||
|
jest.mock('openai', () => {
|
||||||
describe("OpenAiNativeHandler", () => {
|
return {
|
||||||
let handler: OpenAiNativeHandler
|
__esModule: true,
|
||||||
let mockOptions: ApiHandlerOptions
|
default: jest.fn().mockImplementation(() => ({
|
||||||
let mockOpenAIClient: jest.Mocked<OpenAI>
|
|
||||||
let mockCreate: jest.Mock
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
// Reset mocks
|
|
||||||
jest.clearAllMocks()
|
|
||||||
|
|
||||||
// Setup mock options
|
|
||||||
mockOptions = {
|
|
||||||
openAiNativeApiKey: "test-api-key",
|
|
||||||
apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup mock create function
|
|
||||||
mockCreate = jest.fn()
|
|
||||||
|
|
||||||
// Setup mock OpenAI client
|
|
||||||
mockOpenAIClient = {
|
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: mockCreate,
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
},
|
if (!options.stream) {
|
||||||
},
|
return {
|
||||||
} as unknown as jest.Mocked<OpenAI>
|
id: 'test-completion',
|
||||||
|
choices: [{
|
||||||
// Mock OpenAI constructor
|
message: { role: 'assistant', content: 'Test response' },
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIClient)
|
finish_reason: 'stop',
|
||||||
|
index: 0
|
||||||
// Create handler instance
|
}],
|
||||||
handler = new OpenAiNativeHandler(mockOptions)
|
|
||||||
})
|
|
||||||
|
|
||||||
describe("constructor", () => {
|
|
||||||
it("should initialize with provided options", () => {
|
|
||||||
expect(OpenAI).toHaveBeenCalledWith({
|
|
||||||
apiKey: mockOptions.openAiNativeApiKey,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe("getModel", () => {
|
|
||||||
it("should return specified model when valid", () => {
|
|
||||||
const result = handler.getModel()
|
|
||||||
expect(result.id).toBe("gpt-4o") // Use the correct model ID
|
|
||||||
})
|
|
||||||
|
|
||||||
it("should return default model when model ID is invalid", () => {
|
|
||||||
handler = new OpenAiNativeHandler({
|
|
||||||
...mockOptions,
|
|
||||||
apiModelId: "invalid-model" as any,
|
|
||||||
})
|
|
||||||
const result = handler.getModel()
|
|
||||||
expect(result.id).toBe(openAiNativeDefaultModelId)
|
|
||||||
})
|
|
||||||
|
|
||||||
it("should return default model when model ID is not provided", () => {
|
|
||||||
handler = new OpenAiNativeHandler({
|
|
||||||
...mockOptions,
|
|
||||||
apiModelId: undefined,
|
|
||||||
})
|
|
||||||
const result = handler.getModel()
|
|
||||||
expect(result.id).toBe(openAiNativeDefaultModelId)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe("createMessage", () => {
|
|
||||||
const systemPrompt = "You are a helpful assistant"
|
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
|
||||||
{ role: "user", content: "Hello" },
|
|
||||||
]
|
|
||||||
|
|
||||||
describe("o1 models", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
handler = new OpenAiNativeHandler({
|
|
||||||
...mockOptions,
|
|
||||||
apiModelId: "o1-preview",
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it("should handle non-streaming response for o1 models", async () => {
|
|
||||||
const mockResponse = {
|
|
||||||
choices: [{ message: { content: "Hello there!" } }],
|
|
||||||
usage: {
|
usage: {
|
||||||
prompt_tokens: 10,
|
prompt_tokens: 10,
|
||||||
completion_tokens: 5,
|
completion_tokens: 5,
|
||||||
},
|
total_tokens: 15
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
mockCreate.mockResolvedValueOnce(mockResponse)
|
return {
|
||||||
|
[Symbol.asyncIterator]: async function* () {
|
||||||
const generator = handler.createMessage(systemPrompt, messages)
|
yield {
|
||||||
const results = []
|
choices: [{
|
||||||
for await (const result of generator) {
|
delta: { content: 'Test response' },
|
||||||
results.push(result)
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: null
|
||||||
|
};
|
||||||
|
yield {
|
||||||
|
choices: [{
|
||||||
|
delta: {},
|
||||||
|
index: 0
|
||||||
|
}],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15
|
||||||
}
|
}
|
||||||
|
};
|
||||||
expect(results).toEqual([
|
|
||||||
{ type: "text", text: "Hello there!" },
|
|
||||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
|
||||||
])
|
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
|
||||||
model: "o1-preview",
|
|
||||||
messages: [
|
|
||||||
{ role: "user", content: systemPrompt },
|
|
||||||
{ role: "user", content: "Hello" },
|
|
||||||
],
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it("should handle missing content in response", async () => {
|
|
||||||
const mockResponse = {
|
|
||||||
choices: [{ message: { content: null } }],
|
|
||||||
usage: null,
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
mockCreate.mockResolvedValueOnce(mockResponse)
|
})
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages)
|
|
||||||
const results = []
|
|
||||||
for await (const result of generator) {
|
|
||||||
results.push(result)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
expect(results).toEqual([
|
describe('OpenAiNativeHandler', () => {
|
||||||
{ type: "text", text: "" },
|
let handler: OpenAiNativeHandler;
|
||||||
{ type: "usage", inputTokens: 0, outputTokens: 0 },
|
let mockOptions: ApiHandlerOptions;
|
||||||
])
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe("streaming models", () => {
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
handler = new OpenAiNativeHandler({
|
mockOptions = {
|
||||||
...mockOptions,
|
apiModelId: 'gpt-4o',
|
||||||
apiModelId: "gpt-4o",
|
openAiNativeApiKey: 'test-api-key'
|
||||||
})
|
};
|
||||||
})
|
handler = new OpenAiNativeHandler(mockOptions);
|
||||||
|
mockCreate.mockClear();
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle streaming response", async () => {
|
describe('constructor', () => {
|
||||||
const mockStream = [
|
it('should initialize with provided options', () => {
|
||||||
{ choices: [{ delta: { content: "Hello" } }], usage: null },
|
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
|
||||||
{ choices: [{ delta: { content: " there" } }], usage: null },
|
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||||
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
});
|
||||||
]
|
|
||||||
|
|
||||||
mockCreate.mockResolvedValueOnce(
|
it('should initialize with empty API key', () => {
|
||||||
(async function* () {
|
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||||
for (const chunk of mockStream) {
|
apiModelId: 'gpt-4o',
|
||||||
yield chunk
|
openAiNativeApiKey: ''
|
||||||
|
});
|
||||||
|
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createMessage', () => {
|
||||||
|
const systemPrompt = 'You are a helpful assistant.';
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'Hello!'
|
||||||
}
|
}
|
||||||
})()
|
];
|
||||||
)
|
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages)
|
it('should handle streaming responses', async () => {
|
||||||
const results = []
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
for await (const result of generator) {
|
const chunks: any[] = [];
|
||||||
results.push(result)
|
for await (const chunk of stream) {
|
||||||
|
chunks.push(chunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(results).toEqual([
|
expect(chunks.length).toBeGreaterThan(0);
|
||||||
{ type: "text", text: "Hello" },
|
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||||
{ type: "text", text: " there" },
|
expect(textChunks).toHaveLength(1);
|
||||||
{ type: "text", text: "!" },
|
expect(textChunks[0].text).toBe('Test response');
|
||||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
});
|
||||||
])
|
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
it('should handle API errors', async () => {
|
||||||
model: "gpt-4o",
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
temperature: 0,
|
|
||||||
messages: [
|
|
||||||
{ role: "system", content: systemPrompt },
|
|
||||||
{ role: "user", content: "Hello" },
|
|
||||||
],
|
|
||||||
stream: true,
|
|
||||||
stream_options: { include_usage: true },
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it("should handle empty delta content", async () => {
|
const stream = handler.createMessage(systemPrompt, messages);
|
||||||
const mockStream = [
|
|
||||||
{ choices: [{ delta: {} }], usage: null },
|
|
||||||
{ choices: [{ delta: { content: null } }], usage: null },
|
|
||||||
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
|
||||||
]
|
|
||||||
|
|
||||||
mockCreate.mockResolvedValueOnce(
|
|
||||||
(async function* () {
|
|
||||||
for (const chunk of mockStream) {
|
|
||||||
yield chunk
|
|
||||||
}
|
|
||||||
})()
|
|
||||||
)
|
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages)
|
|
||||||
const results = []
|
|
||||||
for await (const result of generator) {
|
|
||||||
results.push(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(results).toEqual([
|
|
||||||
{ type: "text", text: "Hello" },
|
|
||||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
|
||||||
])
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it("should handle API errors", async () => {
|
|
||||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages)
|
|
||||||
await expect(async () => {
|
await expect(async () => {
|
||||||
for await (const _ of generator) {
|
for await (const chunk of stream) {
|
||||||
// consume generator
|
// Should not reach here
|
||||||
}
|
}
|
||||||
}).rejects.toThrow("API Error")
|
}).rejects.toThrow('API Error');
|
||||||
})
|
});
|
||||||
})
|
});
|
||||||
})
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully with gpt-4o model', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: 'gpt-4o',
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
temperature: 0
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should complete prompt successfully with o1 model', async () => {
|
||||||
|
handler = new OpenAiNativeHandler({
|
||||||
|
apiModelId: 'o1',
|
||||||
|
openAiNativeApiKey: 'test-api-key'
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: 'o1',
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should complete prompt successfully with o1-preview model', async () => {
|
||||||
|
handler = new OpenAiNativeHandler({
|
||||||
|
apiModelId: 'o1-preview',
|
||||||
|
openAiNativeApiKey: 'test-api-key'
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: 'o1-preview',
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should complete prompt successfully with o1-mini model', async () => {
|
||||||
|
handler = new OpenAiNativeHandler({
|
||||||
|
apiModelId: 'o1-mini',
|
||||||
|
openAiNativeApiKey: 'test-api-key'
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: 'o1-mini',
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('OpenAI Native completion error: API Error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
mockCreate.mockResolvedValueOnce({
|
||||||
|
choices: [{ message: { content: '' } }]
|
||||||
|
});
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getModel', () => {
|
||||||
|
it('should return model info', () => {
|
||||||
|
const modelInfo = handler.getModel();
|
||||||
|
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||||
|
expect(modelInfo.info).toBeDefined();
|
||||||
|
expect(modelInfo.info.maxTokens).toBe(4096);
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle undefined model ID', () => {
|
||||||
|
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||||
|
openAiNativeApiKey: 'test-api-key'
|
||||||
|
});
|
||||||
|
const modelInfo = handlerWithoutModel.getModel();
|
||||||
|
expect(modelInfo.id).toBe('gpt-4o'); // Default model
|
||||||
|
expect(modelInfo.info).toBeDefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -176,6 +176,32 @@ describe('OpenAiHandler', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.openAiModelId,
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
temperature: 0
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('OpenAI completion error: API Error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
mockCreate.mockImplementationOnce(() => ({
|
||||||
|
choices: [{ message: { content: '' } }]
|
||||||
|
}));
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return model info with sane defaults', () => {
|
it('should return model info with sane defaults', () => {
|
||||||
const model = handler.getModel();
|
const model = handler.getModel();
|
||||||
|
|||||||
@@ -6,7 +6,42 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
|
|||||||
jest.mock('@anthropic-ai/vertex-sdk', () => ({
|
jest.mock('@anthropic-ai/vertex-sdk', () => ({
|
||||||
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
||||||
messages: {
|
messages: {
|
||||||
create: jest.fn()
|
create: jest.fn().mockImplementation(async (options) => {
|
||||||
|
if (!options.stream) {
|
||||||
|
return {
|
||||||
|
id: 'test-completion',
|
||||||
|
content: [
|
||||||
|
{ type: 'text', text: 'Test response' }
|
||||||
|
],
|
||||||
|
role: 'assistant',
|
||||||
|
model: options.model,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
type: 'message_start',
|
||||||
|
message: {
|
||||||
|
usage: {
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
type: 'content_block_start',
|
||||||
|
content_block: {
|
||||||
|
type: 'text',
|
||||||
|
text: 'Test response'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}));
|
}));
|
||||||
@@ -196,6 +231,49 @@ describe('VertexHandler', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('completePrompt', () => {
|
||||||
|
it('should complete prompt successfully', async () => {
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('Test response');
|
||||||
|
expect(handler['client'].messages.create).toHaveBeenCalledWith({
|
||||||
|
model: 'claude-3-5-sonnet-v2@20241022',
|
||||||
|
max_tokens: 8192,
|
||||||
|
temperature: 0,
|
||||||
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
stream: false
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
const mockError = new Error('Vertex API error');
|
||||||
|
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||||
|
(handler['client'].messages as any).create = mockCreate;
|
||||||
|
|
||||||
|
await expect(handler.completePrompt('Test prompt'))
|
||||||
|
.rejects.toThrow('Vertex completion error: Vertex API error');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle non-text content', async () => {
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue({
|
||||||
|
content: [{ type: 'image' }]
|
||||||
|
});
|
||||||
|
(handler['client'].messages as any).create = mockCreate;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty response', async () => {
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue({
|
||||||
|
content: [{ type: 'text', text: '' }]
|
||||||
|
});
|
||||||
|
(handler['client'].messages as any).create = mockCreate;
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('Test prompt');
|
||||||
|
expect(result).toBe('');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe('getModel', () => {
|
||||||
it('should return correct model info', () => {
|
it('should return correct model info', () => {
|
||||||
const modelInfo = handler.getModel();
|
const modelInfo = handler.getModel();
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import {
|
|||||||
ApiHandlerOptions,
|
ApiHandlerOptions,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
} from "../../shared/api"
|
} from "../../shared/api"
|
||||||
import { ApiHandler } from "../index"
|
import { ApiHandler, SingleCompletionHandler } from "../index"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
export class AnthropicHandler implements ApiHandler {
|
export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: Anthropic
|
private client: Anthropic
|
||||||
|
|
||||||
@@ -173,4 +173,27 @@ export class AnthropicHandler implements ApiHandler {
|
|||||||
}
|
}
|
||||||
return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
|
return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const response = await this.client.messages.create({
|
||||||
|
model: this.getModel().id,
|
||||||
|
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||||
|
temperature: 0,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
|
||||||
|
const content = response.content[0]
|
||||||
|
if (content.type === 'text') {
|
||||||
|
return content.text
|
||||||
|
}
|
||||||
|
return ''
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Anthropic completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { BedrockRuntimeClient, ConverseStreamCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
|
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
|
||||||
@@ -38,7 +38,7 @@ export interface StreamEvent {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export class AwsBedrockHandler implements ApiHandler {
|
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: BedrockRuntimeClient
|
private client: BedrockRuntimeClient
|
||||||
|
|
||||||
@@ -219,4 +219,63 @@ export class AwsBedrockHandler implements ApiHandler {
|
|||||||
info: bedrockModels[bedrockDefaultModelId]
|
info: bedrockModels[bedrockDefaultModelId]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const modelConfig = this.getModel()
|
||||||
|
|
||||||
|
// Handle cross-region inference
|
||||||
|
let modelId: string
|
||||||
|
if (this.options.awsUseCrossRegionInference) {
|
||||||
|
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||||
|
switch (regionPrefix) {
|
||||||
|
case "us-":
|
||||||
|
modelId = `us.${modelConfig.id}`
|
||||||
|
break
|
||||||
|
case "eu-":
|
||||||
|
modelId = `eu.${modelConfig.id}`
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
modelId = modelConfig.id
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
modelId = modelConfig.id
|
||||||
|
}
|
||||||
|
|
||||||
|
const payload = {
|
||||||
|
modelId,
|
||||||
|
messages: convertToBedrockConverseMessages([{
|
||||||
|
role: "user",
|
||||||
|
content: prompt
|
||||||
|
}]),
|
||||||
|
inferenceConfig: {
|
||||||
|
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||||
|
temperature: 0.3,
|
||||||
|
topP: 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const command = new ConverseCommand(payload)
|
||||||
|
const response = await this.client.send(command)
|
||||||
|
|
||||||
|
if (response.output && response.output instanceof Uint8Array) {
|
||||||
|
try {
|
||||||
|
const outputStr = new TextDecoder().decode(response.output)
|
||||||
|
const output = JSON.parse(outputStr)
|
||||||
|
if (output.content) {
|
||||||
|
return output.content
|
||||||
|
}
|
||||||
|
} catch (parseError) {
|
||||||
|
console.error('Failed to parse Bedrock response:', parseError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ''
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { GoogleGenerativeAI } from "@google/generative-ai"
|
import { GoogleGenerativeAI } from "@google/generative-ai"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
|
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
|
||||||
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
|
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
export class GeminiHandler implements ApiHandler {
|
export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: GoogleGenerativeAI
|
private client: GoogleGenerativeAI
|
||||||
|
|
||||||
@@ -53,4 +53,26 @@ export class GeminiHandler implements ApiHandler {
|
|||||||
}
|
}
|
||||||
return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
|
return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const model = this.client.getGenerativeModel({
|
||||||
|
model: this.getModel().id,
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await model.generateContent({
|
||||||
|
contents: [{ role: "user", parts: [{ text: prompt }] }],
|
||||||
|
generationConfig: {
|
||||||
|
temperature: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return result.response.text()
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Gemini completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import axios from "axios"
|
import axios from "axios"
|
||||||
import OpenAI from "openai"
|
import OpenAI from "openai"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
|
import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
import delay from "delay"
|
import delay from "delay"
|
||||||
|
|
||||||
export class GlamaHandler implements ApiHandler {
|
export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: OpenAI
|
private client: OpenAI
|
||||||
|
|
||||||
@@ -129,4 +129,26 @@ export class GlamaHandler implements ApiHandler {
|
|||||||
|
|
||||||
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
|
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||||
|
model: this.getModel().id,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
temperature: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.getModel().id.startsWith("anthropic/")) {
|
||||||
|
requestOptions.max_tokens = 8192
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await this.client.chat.completions.create(requestOptions)
|
||||||
|
return response.choices[0]?.message.content || ""
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Glama completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import OpenAI from "openai"
|
import OpenAI from "openai"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
export class LmStudioHandler implements ApiHandler {
|
export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: OpenAI
|
private client: OpenAI
|
||||||
|
|
||||||
@@ -53,4 +53,20 @@ export class LmStudioHandler implements ApiHandler {
|
|||||||
info: openAiModelInfoSaneDefaults,
|
info: openAiModelInfoSaneDefaults,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.getModel().id,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
return response.choices[0]?.message.content || ""
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(
|
||||||
|
"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts.",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import OpenAI from "openai"
|
import OpenAI from "openai"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
export class OllamaHandler implements ApiHandler {
|
export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: OpenAI
|
private client: OpenAI
|
||||||
|
|
||||||
@@ -46,4 +46,21 @@ export class OllamaHandler implements ApiHandler {
|
|||||||
info: openAiModelInfoSaneDefaults,
|
info: openAiModelInfoSaneDefaults,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.getModel().id,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
return response.choices[0]?.message.content || ""
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Ollama completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import OpenAI from "openai"
|
import OpenAI from "openai"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import {
|
import {
|
||||||
ApiHandlerOptions,
|
ApiHandlerOptions,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
@@ -11,7 +11,7 @@ import {
|
|||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
export class OpenAiNativeHandler implements ApiHandler {
|
export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: OpenAI
|
private client: OpenAI
|
||||||
|
|
||||||
@@ -83,4 +83,37 @@ export class OpenAiNativeHandler implements ApiHandler {
|
|||||||
}
|
}
|
||||||
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
|
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const modelId = this.getModel().id
|
||||||
|
let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
|
||||||
|
|
||||||
|
switch (modelId) {
|
||||||
|
case "o1":
|
||||||
|
case "o1-preview":
|
||||||
|
case "o1-mini":
|
||||||
|
// o1 doesn't support non-1 temp or system prompt
|
||||||
|
requestOptions = {
|
||||||
|
model: modelId,
|
||||||
|
messages: [{ role: "user", content: prompt }]
|
||||||
|
}
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
requestOptions = {
|
||||||
|
model: modelId,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
temperature: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await this.client.chat.completions.create(requestOptions)
|
||||||
|
return response.choices[0]?.message.content || ""
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`OpenAI Native completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,11 @@ import {
|
|||||||
ModelInfo,
|
ModelInfo,
|
||||||
openAiModelInfoSaneDefaults,
|
openAiModelInfoSaneDefaults,
|
||||||
} from "../../shared/api"
|
} from "../../shared/api"
|
||||||
import { ApiHandler } from "../index"
|
import { ApiHandler, SingleCompletionHandler } from "../index"
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
export class OpenAiHandler implements ApiHandler {
|
export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
protected options: ApiHandlerOptions
|
protected options: ApiHandlerOptions
|
||||||
private client: OpenAI
|
private client: OpenAI
|
||||||
|
|
||||||
@@ -100,4 +100,22 @@ export class OpenAiHandler implements ApiHandler {
|
|||||||
info: openAiModelInfoSaneDefaults,
|
info: openAiModelInfoSaneDefaults,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||||
|
model: this.getModel().id,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
temperature: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await this.client.chat.completions.create(requestOptions)
|
||||||
|
return response.choices[0]?.message.content || ""
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`OpenAI completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
|
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
|
||||||
import { ApiHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
|
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
// https://docs.anthropic.com/en/api/claude-on-vertex-ai
|
// https://docs.anthropic.com/en/api/claude-on-vertex-ai
|
||||||
export class VertexHandler implements ApiHandler {
|
export class VertexHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: AnthropicVertex
|
private client: AnthropicVertex
|
||||||
|
|
||||||
@@ -83,4 +83,27 @@ export class VertexHandler implements ApiHandler {
|
|||||||
}
|
}
|
||||||
return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
|
return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const response = await this.client.messages.create({
|
||||||
|
model: this.getModel().id,
|
||||||
|
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||||
|
temperature: 0,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
|
||||||
|
const content = response.content[0]
|
||||||
|
if (content.type === 'text') {
|
||||||
|
return content.text
|
||||||
|
}
|
||||||
|
return ''
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Vertex completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user