Mirror of https://github.com/pacnpal/Roo-Code.git (synced 2025-12-21 04:41:16 -05:00)
Add non-streaming completePrompt to all providers
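In short, each provider gains a completePrompt() that issues a single non-streaming request and returns the response text, alongside the existing streaming createMessage() path. Below is a minimal sketch of what the Ollama variant might look like, assuming the OpenAI-compatible client that the updated test mocks; the standalone function shape and the error wording are inferred from the test expectations in this diff, not copied from the provider source.

import OpenAI from 'openai';

// Sketch only: behaviour inferred from the Ollama test expectations below,
// not the actual OllamaHandler implementation.
export async function completePromptSketch(client: OpenAI, modelId: string, prompt: string): Promise<string> {
    try {
        // Non-streaming: the whole completion arrives in a single response object.
        const response = await client.chat.completions.create({
            model: modelId,
            messages: [{ role: 'user', content: prompt }],
            temperature: 0,
            stream: false
        });
        return response.choices[0]?.message?.content || '';
    } catch (error) {
        // The tests expect provider errors to surface with a provider-specific prefix.
        const message = error instanceof Error ? error.message : String(error);
        throw new Error(`Ollama completion error: ${message}`);
    }
}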
@@ -1,148 +1,160 @@
 import { OllamaHandler } from '../ollama';
-import { Anthropic } from '@anthropic-ai/sdk';
 import { ApiHandlerOptions } from '../../../shared/api';
+import OpenAI from 'openai';
+import { Anthropic } from '@anthropic-ai/sdk';
 
-// Mock OpenAI SDK
-jest.mock('openai', () => ({
-    __esModule: true,
-    default: jest.fn().mockImplementation(() => ({
-        chat: {
-            completions: {
-                create: jest.fn()
+// Mock OpenAI client
+const mockCreate = jest.fn();
+jest.mock('openai', () => {
+    return {
+        __esModule: true,
+        default: jest.fn().mockImplementation(() => ({
+            chat: {
+                completions: {
+                    create: mockCreate.mockImplementation(async (options) => {
+                        if (!options.stream) {
+                            return {
+                                id: 'test-completion',
+                                choices: [{
+                                    message: { role: 'assistant', content: 'Test response' },
+                                    finish_reason: 'stop',
+                                    index: 0
+                                }],
+                                usage: {
+                                    prompt_tokens: 10,
+                                    completion_tokens: 5,
+                                    total_tokens: 15
+                                }
+                            };
+                        }
+
+                        return {
+                            [Symbol.asyncIterator]: async function* () {
+                                yield {
+                                    choices: [{
+                                        delta: { content: 'Test response' },
+                                        index: 0
+                                    }],
+                                    usage: null
+                                };
+                                yield {
+                                    choices: [{
+                                        delta: {},
+                                        index: 0
+                                    }],
+                                    usage: {
+                                        prompt_tokens: 10,
+                                        completion_tokens: 5,
+                                        total_tokens: 15
+                                    }
+                                };
+                            }
+                        };
+                    })
+                }
            }
-        }
-    }))
-}));
+        }))
+    };
+});
 
 describe('OllamaHandler', () => {
     let handler: OllamaHandler;
+    let mockOptions: ApiHandlerOptions;
 
     beforeEach(() => {
-        handler = new OllamaHandler({
+        mockOptions = {
             apiModelId: 'llama2',
             ollamaModelId: 'llama2',
-            ollamaBaseUrl: 'http://localhost:11434'
-        });
+            ollamaBaseUrl: 'http://localhost:11434/v1'
+        };
+        handler = new OllamaHandler(mockOptions);
+        mockCreate.mockClear();
     });
 
     describe('constructor', () => {
-        it('should initialize with provided config', () => {
-            expect(OpenAI).toHaveBeenCalledWith({
-                baseURL: 'http://localhost:11434/v1',
-                apiKey: 'ollama'
-            });
+        it('should initialize with provided options', () => {
+            expect(handler).toBeInstanceOf(OllamaHandler);
+            expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
         });
 
         it('should use default base URL if not provided', () => {
-            const defaultHandler = new OllamaHandler({
+            const handlerWithoutUrl = new OllamaHandler({
                 apiModelId: 'llama2',
                 ollamaModelId: 'llama2'
             });
-
-            expect(OpenAI).toHaveBeenCalledWith({
-                baseURL: 'http://localhost:11434/v1',
-                apiKey: 'ollama'
-            });
+            expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
         });
     });
 
     describe('createMessage', () => {
-        const mockMessages: Anthropic.Messages.MessageParam[] = [
+        const systemPrompt = 'You are a helpful assistant.';
+        const messages: Anthropic.Messages.MessageParam[] = [
             {
                 role: 'user',
                 content: 'Hello'
             },
             {
                 role: 'assistant',
-                content: 'Hi there!'
+                content: 'Hello!'
             }
         ];
 
-        const systemPrompt = 'You are a helpful assistant';
-
-        it('should handle streaming responses correctly', async () => {
-            const mockStream = [
-                {
-                    choices: [{
-                        delta: { content: 'Hello' }
-                    }]
-                },
-                {
-                    choices: [{
-                        delta: { content: ' world!' }
-                    }]
-                }
-            ];
-
-            // Setup async iterator for mock stream
-            const asyncIterator = {
-                async *[Symbol.asyncIterator]() {
-                    for (const chunk of mockStream) {
-                        yield chunk;
-                    }
-                }
-            };
-
-            const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
-            (handler['client'].chat.completions as any).create = mockCreate;
-
-            const stream = handler.createMessage(systemPrompt, mockMessages);
-            const chunks = [];
-
+        it('should handle streaming responses', async () => {
+            const stream = handler.createMessage(systemPrompt, messages);
+            const chunks: any[] = [];
             for await (const chunk of stream) {
                 chunks.push(chunk);
             }
 
-            expect(chunks.length).toBe(2);
-            expect(chunks[0]).toEqual({
-                type: 'text',
-                text: 'Hello'
-            });
-            expect(chunks[1]).toEqual({
-                type: 'text',
-                text: ' world!'
-            });
+            expect(chunks.length).toBeGreaterThan(0);
+            const textChunks = chunks.filter(chunk => chunk.type === 'text');
+            expect(textChunks).toHaveLength(1);
+            expect(textChunks[0].text).toBe('Test response');
         });
 
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+
+            const stream = handler.createMessage(systemPrompt, messages);
+
+            await expect(async () => {
+                for await (const chunk of stream) {
+                    // Should not reach here
+                }
+            }).rejects.toThrow('API Error');
+        });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
             expect(mockCreate).toHaveBeenCalledWith({
-                model: 'llama2',
-                messages: expect.arrayContaining([
-                    {
-                        role: 'system',
-                        content: systemPrompt
-                    }
-                ]),
+                model: mockOptions.ollamaModelId,
+                messages: [{ role: 'user', content: 'Test prompt' }],
                 temperature: 0,
-                stream: true
+                stream: false
             });
         });
 
         it('should handle API errors', async () => {
-            const mockError = new Error('Ollama API error');
-            const mockCreate = jest.fn().mockRejectedValue(mockError);
-            (handler['client'].chat.completions as any).create = mockCreate;
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Ollama completion error: API Error');
         });
 
-            const stream = handler.createMessage(systemPrompt, mockMessages);
-
-            await expect(async () => {
-                for await (const chunk of stream) {
-                    // Should throw before yielding any chunks
-                }
-            }).rejects.toThrow('Ollama API error');
+        it('should handle empty response', async () => {
+            mockCreate.mockResolvedValueOnce({
+                choices: [{ message: { content: '' } }]
+            });
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
     });
 
     describe('getModel', () => {
-        it('should return model info with sane defaults', () => {
+        it('should return model info', () => {
             const modelInfo = handler.getModel();
-            expect(modelInfo.id).toBe('llama2');
+            expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
             expect(modelInfo.info).toBeDefined();
             expect(modelInfo.info.maxTokens).toBe(-1);
             expect(modelInfo.info.contextWindow).toBe(128_000);
         });
 
         it('should return empty string as model ID if not provided', () => {
             const noModelHandler = new OllamaHandler({});
             const modelInfo = noModelHandler.getModel();
             expect(modelInfo.id).toBe('');
             expect(modelInfo.info).toBeDefined();
         });
     });
 });
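As a usage note, here is a rough sketch of how the two paths differ from a caller's point of view, using the same constructor options and chunk shape the tests above exercise; the import path and the surrounding function are illustrative, not taken from the repository.

import { OllamaHandler } from '../ollama'; // same path the test file uses; adjust for the caller's location

// Illustrative only: option values mirror the test fixtures above.
async function demo(): Promise<void> {
    const handler = new OllamaHandler({
        ollamaModelId: 'llama2',
        ollamaBaseUrl: 'http://localhost:11434'
    });

    // New non-streaming path: one request, a plain string back.
    const answer = await handler.completePrompt('Say hello');
    console.log(answer);

    // Existing streaming path: async-iterable chunks, as the createMessage tests show.
    for await (const chunk of handler.createMessage('You are a helpful assistant.', [
        { role: 'user', content: 'Say hello' }
    ])) {
        if (chunk.type === 'text') {
            console.log(chunk.text);
        }
    }
}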