mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 12:21:13 -05:00
Add non-streaming completePrompt to all providers
This commit is contained in:
@@ -46,7 +46,42 @@ jest.mock('@anthropic-ai/sdk', () => {
|
||||
}
|
||||
},
|
||||
messages: {
|
||||
create: mockCreate
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
content: [
|
||||
{ type: 'text', text: 'Test response' }
|
||||
],
|
||||
role: 'assistant',
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
};
|
||||
@@ -144,6 +179,42 @@ describe('AnthropicHandler', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.apiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Anthropic completion error: API Error');
|
||||
});
|
||||
|
||||
it('should handle non-text content', async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: 'image' }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: 'text', text: '' }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return default model if no model ID is provided', () => {
|
||||
const handlerWithoutModel = new AnthropicHandler({
|
||||
|
||||
@@ -119,6 +119,108 @@ describe('AwsBedrockHandler', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({
|
||||
content: 'Test response'
|
||||
}))
|
||||
};
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'user',
|
||||
content: [{ text: 'Test prompt' }]
|
||||
})
|
||||
]),
|
||||
inferenceConfig: expect.objectContaining({
|
||||
maxTokens: 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
})
|
||||
})
|
||||
}));
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('AWS Bedrock error');
|
||||
const mockSend = jest.fn().mockRejectedValue(mockError);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
|
||||
});
|
||||
|
||||
it('should handle invalid response format', async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode('invalid json')
|
||||
};
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({}))
|
||||
};
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should handle cross-region inference', async () => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1',
|
||||
awsUseCrossRegionInference: true
|
||||
});
|
||||
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({
|
||||
content: 'Test response'
|
||||
}))
|
||||
};
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
|
||||
})
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info in test environment', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
|
||||
@@ -6,7 +6,12 @@ import { GoogleGenerativeAI } from '@google/generative-ai';
|
||||
jest.mock('@google/generative-ai', () => ({
|
||||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
||||
getGenerativeModel: jest.fn().mockReturnValue({
|
||||
generateContentStream: jest.fn()
|
||||
generateContentStream: jest.fn(),
|
||||
generateContent: jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => 'Test response'
|
||||
}
|
||||
})
|
||||
})
|
||||
}))
|
||||
}));
|
||||
@@ -133,6 +138,59 @@ describe('GeminiHandler', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => 'Test response'
|
||||
}
|
||||
});
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||
model: 'gemini-2.0-flash-thinking-exp-1219'
|
||||
});
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith({
|
||||
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
|
||||
generationConfig: {
|
||||
temperature: 0
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Gemini API error');
|
||||
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Gemini completion error: Gemini API error');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => ''
|
||||
}
|
||||
});
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
|
||||
226
src/api/providers/__tests__/glama.test.ts
Normal file
226
src/api/providers/__tests__/glama.test.ts
Normal file
@@ -0,0 +1,226 @@
|
||||
import { GlamaHandler } from '../glama';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import axios from 'axios';
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
const mockWithResponse = jest.fn();
|
||||
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: (...args: any[]) => {
|
||||
const stream = {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const result = mockCreate(...args);
|
||||
if (args[0].stream) {
|
||||
mockWithResponse.mockReturnValue(Promise.resolve({
|
||||
data: stream,
|
||||
response: {
|
||||
headers: {
|
||||
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
|
||||
}
|
||||
}
|
||||
}));
|
||||
result.withResponse = mockWithResponse;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('GlamaHandler', () => {
|
||||
let handler: GlamaHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'anthropic/claude-3-5-sonnet',
|
||||
glamaModelId: 'anthropic/claude-3-5-sonnet',
|
||||
glamaApiKey: 'test-api-key'
|
||||
};
|
||||
handler = new GlamaHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
mockWithResponse.mockClear();
|
||||
|
||||
// Default mock implementation for non-streaming responses
|
||||
mockCreate.mockResolvedValue({
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(GlamaHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
// Mock axios for token usage request
|
||||
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
|
||||
data: {
|
||||
tokenUsage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
cacheCreationInputTokens: 0,
|
||||
cacheReadInputTokens: 0
|
||||
},
|
||||
totalCostUsd: "0.00"
|
||||
}
|
||||
});
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.length).toBe(2); // Text chunk and usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
});
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'usage',
|
||||
inputTokens: 10,
|
||||
outputTokens: 5,
|
||||
cacheWriteTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
totalCost: 0
|
||||
});
|
||||
|
||||
mockAxios.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockImplementationOnce(() => {
|
||||
throw new Error('API Error');
|
||||
});
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks = [];
|
||||
|
||||
try {
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
fail('Expected error to be thrown');
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(Error);
|
||||
expect(error.message).toBe('API Error');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: mockOptions.apiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0,
|
||||
max_tokens: 8192
|
||||
}));
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Glama completion error: API Error');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should not set max_tokens for non-Anthropic models', async () => {
|
||||
// Reset mock to clear any previous calls
|
||||
mockCreate.mockClear();
|
||||
|
||||
const nonAnthropicOptions = {
|
||||
apiModelId: 'openai/gpt-4',
|
||||
glamaModelId: 'openai/gpt-4',
|
||||
glamaApiKey: 'test-key',
|
||||
glamaModelInfo: {
|
||||
maxTokens: 4096,
|
||||
contextWindow: 8192,
|
||||
supportsImages: true,
|
||||
supportsPromptCache: false
|
||||
}
|
||||
};
|
||||
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
|
||||
|
||||
await nonAnthropicHandler.completePrompt('Test prompt');
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: 'openai/gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
}));
|
||||
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,148 +1,160 @@
|
||||
import { LmStudioHandler } from '../lmstudio';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
|
||||
// Mock OpenAI SDK
|
||||
jest.mock('openai', () => ({
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: jest.fn()
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
}));
|
||||
};
|
||||
});
|
||||
|
||||
describe('LmStudioHandler', () => {
|
||||
let handler: LmStudioHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new LmStudioHandler({
|
||||
lmStudioModelId: 'mistral-7b',
|
||||
lmStudioBaseUrl: 'http://localhost:1234'
|
||||
});
|
||||
mockOptions = {
|
||||
apiModelId: 'local-model',
|
||||
lmStudioModelId: 'local-model',
|
||||
lmStudioBaseUrl: 'http://localhost:1234/v1'
|
||||
};
|
||||
handler = new LmStudioHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: 'http://localhost:1234/v1',
|
||||
apiKey: 'noop'
|
||||
});
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(LmStudioHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
|
||||
});
|
||||
|
||||
it('should use default base URL if not provided', () => {
|
||||
const defaultHandler = new LmStudioHandler({
|
||||
lmStudioModelId: 'mistral-7b'
|
||||
});
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: 'http://localhost:1234/v1',
|
||||
apiKey: 'noop'
|
||||
const handlerWithoutUrl = new LmStudioHandler({
|
||||
apiModelId: 'local-model',
|
||||
lmStudioModelId: 'local-model'
|
||||
});
|
||||
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMessage', () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
|
||||
it('should handle streaming responses correctly', async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
choices: [{
|
||||
delta: { content: 'Hello' }
|
||||
}]
|
||||
},
|
||||
{
|
||||
choices: [{
|
||||
delta: { content: ' world!' }
|
||||
}]
|
||||
}
|
||||
];
|
||||
|
||||
// Setup async iterator for mock stream
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
||||
(handler['client'].chat.completions as any).create = mockCreate;
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.length).toBe(2);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: ' world!'
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'mistral-7b',
|
||||
messages: expect.arrayContaining([
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPrompt
|
||||
}
|
||||
]),
|
||||
temperature: 0,
|
||||
stream: true
|
||||
});
|
||||
});
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
|
||||
it('should handle API errors with custom message', async () => {
|
||||
const mockError = new Error('LM Studio API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].chat.completions as any).create = mockCreate;
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.lmStudioModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info with sane defaults', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('mistral-7b');
|
||||
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
|
||||
it('should return empty string as model ID if not provided', () => {
|
||||
const noModelHandler = new LmStudioHandler({});
|
||||
const modelInfo = noModelHandler.getModel();
|
||||
expect(modelInfo.id).toBe('');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,148 +1,160 @@
|
||||
import { OllamaHandler } from '../ollama';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
|
||||
// Mock OpenAI SDK
|
||||
jest.mock('openai', () => ({
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: jest.fn()
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
}));
|
||||
};
|
||||
});
|
||||
|
||||
describe('OllamaHandler', () => {
|
||||
let handler: OllamaHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new OllamaHandler({
|
||||
mockOptions = {
|
||||
apiModelId: 'llama2',
|
||||
ollamaModelId: 'llama2',
|
||||
ollamaBaseUrl: 'http://localhost:11434'
|
||||
});
|
||||
ollamaBaseUrl: 'http://localhost:11434/v1'
|
||||
};
|
||||
handler = new OllamaHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: 'http://localhost:11434/v1',
|
||||
apiKey: 'ollama'
|
||||
});
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OllamaHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
|
||||
});
|
||||
|
||||
it('should use default base URL if not provided', () => {
|
||||
const defaultHandler = new OllamaHandler({
|
||||
const handlerWithoutUrl = new OllamaHandler({
|
||||
apiModelId: 'llama2',
|
||||
ollamaModelId: 'llama2'
|
||||
});
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: 'http://localhost:11434/v1',
|
||||
apiKey: 'ollama'
|
||||
});
|
||||
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMessage', () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
|
||||
it('should handle streaming responses correctly', async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
choices: [{
|
||||
delta: { content: 'Hello' }
|
||||
}]
|
||||
},
|
||||
{
|
||||
choices: [{
|
||||
delta: { content: ' world!' }
|
||||
}]
|
||||
}
|
||||
];
|
||||
|
||||
// Setup async iterator for mock stream
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
||||
(handler['client'].chat.completions as any).create = mockCreate;
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.length).toBe(2);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: ' world!'
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'llama2',
|
||||
messages: expect.arrayContaining([
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPrompt
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
]),
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.ollamaModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0,
|
||||
stream: true
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Ollama API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].chat.completions as any).create = mockCreate;
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Ollama completion error: API Error');
|
||||
});
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('Ollama API error');
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info with sane defaults', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('llama2');
|
||||
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
|
||||
it('should return empty string as model ID if not provided', () => {
|
||||
const noModelHandler = new OllamaHandler({});
|
||||
const modelInfo = noModelHandler.getModel();
|
||||
expect(modelInfo.id).toBe('');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,230 +1,209 @@
|
||||
import { OpenAiNativeHandler } from "../openai-native"
|
||||
import OpenAI from "openai"
|
||||
import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { OpenAiNativeHandler } from '../openai-native';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
|
||||
// Mock OpenAI
|
||||
jest.mock("openai")
|
||||
|
||||
describe("OpenAiNativeHandler", () => {
|
||||
let handler: OpenAiNativeHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
let mockOpenAIClient: jest.Mocked<OpenAI>
|
||||
let mockCreate: jest.Mock
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Setup mock options
|
||||
mockOptions = {
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts
|
||||
}
|
||||
|
||||
// Setup mock create function
|
||||
mockCreate = jest.fn()
|
||||
|
||||
// Setup mock OpenAI client
|
||||
mockOpenAIClient = {
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate,
|
||||
},
|
||||
},
|
||||
} as unknown as jest.Mocked<OpenAI>
|
||||
|
||||
// Mock OpenAI constructor
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIClient)
|
||||
|
||||
// Create handler instance
|
||||
handler = new OpenAiNativeHandler(mockOptions)
|
||||
})
|
||||
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: mockOptions.openAiNativeApiKey,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("getModel", () => {
|
||||
it("should return specified model when valid", () => {
|
||||
const result = handler.getModel()
|
||||
expect(result.id).toBe("gpt-4o") // Use the correct model ID
|
||||
})
|
||||
|
||||
it("should return default model when model ID is invalid", () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "invalid-model" as any,
|
||||
})
|
||||
const result = handler.getModel()
|
||||
expect(result.id).toBe(openAiNativeDefaultModelId)
|
||||
})
|
||||
|
||||
it("should return default model when model ID is not provided", () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: undefined,
|
||||
})
|
||||
const result = handler.getModel()
|
||||
expect(result.id).toBe(openAiNativeDefaultModelId)
|
||||
})
|
||||
})
|
||||
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: "user", content: "Hello" },
|
||||
]
|
||||
|
||||
describe("o1 models", () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "o1-preview",
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle non-streaming response for o1 models", async () => {
|
||||
const mockResponse = {
|
||||
choices: [{ message: { content: "Hello there!" } }],
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
},
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mockCreate.mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello there!" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "o1-preview",
|
||||
messages: [
|
||||
{ role: "user", content: systemPrompt },
|
||||
{ role: "user", content: "Hello" },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle missing content in response", async () => {
|
||||
const mockResponse = {
|
||||
choices: [{ message: { content: null } }],
|
||||
usage: null,
|
||||
};
|
||||
}
|
||||
|
||||
mockCreate.mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "" },
|
||||
{ type: "usage", inputTokens: 0, outputTokens: 0 },
|
||||
])
|
||||
})
|
||||
})
|
||||
describe('OpenAiNativeHandler', () => {
|
||||
let handler: OpenAiNativeHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
|
||||
describe("streaming models", () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "gpt-4o",
|
||||
})
|
||||
})
|
||||
mockOptions = {
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
};
|
||||
handler = new OpenAiNativeHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
|
||||
it("should handle streaming response", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: null },
|
||||
{ choices: [{ delta: { content: " there" } }], usage: null },
|
||||
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
it('should initialize with empty API key', () => {
|
||||
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: ''
|
||||
});
|
||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
})()
|
||||
)
|
||||
];
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "text", text: " there" },
|
||||
{ type: "text", text: "!" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "gpt-4o",
|
||||
temperature: 0,
|
||||
messages: [
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: "Hello" },
|
||||
],
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
})
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
|
||||
it("should handle empty delta content", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: {} }], usage: null },
|
||||
{ choices: [{ delta: { content: null } }], usage: null },
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
})()
|
||||
)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
await expect(async () => {
|
||||
for await (const _ of generator) {
|
||||
// consume generator
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
})
|
||||
})
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully with gpt-4o model', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
|
||||
it('should complete prompt successfully with o1 model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should complete prompt successfully with o1-preview model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-preview',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-preview',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should complete prompt successfully with o1-mini model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-mini',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-mini',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI Native completion error: API Error');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(4096);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
const modelInfo = handlerWithoutModel.getModel();
|
||||
expect(modelInfo.id).toBe('gpt-4o'); // Default model
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -176,6 +176,32 @@ describe('OpenAiHandler', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openAiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI completion error: API Error');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockImplementationOnce(() => ({
|
||||
choices: [{ message: { content: '' } }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info with sane defaults', () => {
|
||||
const model = handler.getModel();
|
||||
|
||||
@@ -6,7 +6,42 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
|
||||
jest.mock('@anthropic-ai/vertex-sdk', () => ({
|
||||
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
||||
messages: {
|
||||
create: jest.fn()
|
||||
create: jest.fn().mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
content: [
|
||||
{ type: 'text', text: 'Test response' }
|
||||
],
|
||||
role: 'assistant',
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
}));
|
||||
@@ -196,6 +231,49 @@ describe('VertexHandler', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(handler['client'].messages.create).toHaveBeenCalledWith({
|
||||
model: 'claude-3-5-sonnet-v2@20241022',
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Vertex API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Vertex completion error: Vertex API error');
|
||||
});
|
||||
|
||||
it('should handle non-text content', async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'image' }]
|
||||
});
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: '' }]
|
||||
});
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
|
||||
@@ -7,10 +7,10 @@ import {
|
||||
ApiHandlerOptions,
|
||||
ModelInfo,
|
||||
} from "../../shared/api"
|
||||
import { ApiHandler } from "../index"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../index"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
export class AnthropicHandler implements ApiHandler {
|
||||
export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: Anthropic
|
||||
|
||||
@@ -173,4 +173,27 @@ export class AnthropicHandler implements ApiHandler {
|
||||
}
|
||||
return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const response = await this.client.messages.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
stream: false
|
||||
})
|
||||
|
||||
const content = response.content[0]
|
||||
if (content.type === 'text') {
|
||||
return content.text
|
||||
}
|
||||
return ''
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Anthropic completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
|
||||
@@ -38,7 +38,7 @@ export interface StreamEvent {
|
||||
};
|
||||
}
|
||||
|
||||
export class AwsBedrockHandler implements ApiHandler {
|
||||
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: BedrockRuntimeClient
|
||||
|
||||
@@ -219,4 +219,63 @@ export class AwsBedrockHandler implements ApiHandler {
|
||||
info: bedrockModels[bedrockDefaultModelId]
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([{
|
||||
role: "user",
|
||||
content: prompt
|
||||
}]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
}
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
|
||||
if (response.output && response.output instanceof Uint8Array) {
|
||||
try {
|
||||
const outputStr = new TextDecoder().decode(response.output)
|
||||
const output = JSON.parse(outputStr)
|
||||
if (output.content) {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse Bedrock response:', parseError)
|
||||
}
|
||||
}
|
||||
return ''
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { GoogleGenerativeAI } from "@google/generative-ai"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
|
||||
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
export class GeminiHandler implements ApiHandler {
|
||||
export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: GoogleGenerativeAI
|
||||
|
||||
@@ -53,4 +53,26 @@ export class GeminiHandler implements ApiHandler {
|
||||
}
|
||||
return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const model = this.client.getGenerativeModel({
|
||||
model: this.getModel().id,
|
||||
})
|
||||
|
||||
const result = await model.generateContent({
|
||||
contents: [{ role: "user", parts: [{ text: prompt }] }],
|
||||
generationConfig: {
|
||||
temperature: 0,
|
||||
},
|
||||
})
|
||||
|
||||
return result.response.text()
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Gemini completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import axios from "axios"
|
||||
import OpenAI from "openai"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
|
||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
import delay from "delay"
|
||||
|
||||
export class GlamaHandler implements ApiHandler {
|
||||
export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: OpenAI
|
||||
|
||||
@@ -129,4 +129,26 @@ export class GlamaHandler implements ApiHandler {
|
||||
|
||||
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
}
|
||||
|
||||
if (this.getModel().id.startsWith("anthropic/")) {
|
||||
requestOptions.max_tokens = 8192
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create(requestOptions)
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Glama completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import OpenAI from "openai"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
export class LmStudioHandler implements ApiHandler {
|
||||
export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: OpenAI
|
||||
|
||||
@@ -53,4 +53,20 @@ export class LmStudioHandler implements ApiHandler {
|
||||
info: openAiModelInfoSaneDefaults,
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
})
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts.",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import OpenAI from "openai"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
export class OllamaHandler implements ApiHandler {
|
||||
export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: OpenAI
|
||||
|
||||
@@ -46,4 +46,21 @@ export class OllamaHandler implements ApiHandler {
|
||||
info: openAiModelInfoSaneDefaults,
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
})
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Ollama completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import OpenAI from "openai"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import {
|
||||
ApiHandlerOptions,
|
||||
ModelInfo,
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
export class OpenAiNativeHandler implements ApiHandler {
|
||||
export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: OpenAI
|
||||
|
||||
@@ -83,4 +83,37 @@ export class OpenAiNativeHandler implements ApiHandler {
|
||||
}
|
||||
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelId = this.getModel().id
|
||||
let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
|
||||
|
||||
switch (modelId) {
|
||||
case "o1":
|
||||
case "o1-preview":
|
||||
case "o1-mini":
|
||||
// o1 doesn't support non-1 temp or system prompt
|
||||
requestOptions = {
|
||||
model: modelId,
|
||||
messages: [{ role: "user", content: prompt }]
|
||||
}
|
||||
break
|
||||
default:
|
||||
requestOptions = {
|
||||
model: modelId,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0
|
||||
}
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create(requestOptions)
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`OpenAI Native completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,11 +6,11 @@ import {
|
||||
ModelInfo,
|
||||
openAiModelInfoSaneDefaults,
|
||||
} from "../../shared/api"
|
||||
import { ApiHandler } from "../index"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../index"
|
||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
export class OpenAiHandler implements ApiHandler {
|
||||
export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
protected options: ApiHandlerOptions
|
||||
private client: OpenAI
|
||||
|
||||
@@ -100,4 +100,22 @@ export class OpenAiHandler implements ApiHandler {
|
||||
info: openAiModelInfoSaneDefaults,
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create(requestOptions)
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`OpenAI completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
|
||||
// https://docs.anthropic.com/en/api/claude-on-vertex-ai
|
||||
export class VertexHandler implements ApiHandler {
|
||||
export class VertexHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: AnthropicVertex
|
||||
|
||||
@@ -83,4 +83,27 @@ export class VertexHandler implements ApiHandler {
|
||||
}
|
||||
return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const response = await this.client.messages.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
stream: false
|
||||
})
|
||||
|
||||
const content = response.content[0]
|
||||
if (content.type === 'text') {
|
||||
return content.text
|
||||
}
|
||||
return ''
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Vertex completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user