Add non-streaming completePrompt to all providers

Matt Rubens
2025-01-13 16:16:58 -05:00
parent 2d176e5c92
commit 4027e1c10c
18 changed files with 1235 additions and 438 deletions
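The listing below covers the LM Studio provider's tests. The commit gives every provider handler a non-streaming counterpart to the streaming createMessage(). As a rough sketch of the shape being added (the interface name here is illustrative, and the exact signature is inferred from the tests below rather than confirmed by this diff):

// Illustrative only: a handler that supports one-shot completions exposes a
// single promise-returning method alongside the streaming createMessage().
interface SingleCompletionHandler {
  completePrompt(prompt: string): Promise<string>;
}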


import { LmStudioHandler } from '../lmstudio';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
  return {
    __esModule: true,
    default: jest.fn().mockImplementation(() => ({
      chat: {
        completions: {
          create: mockCreate.mockImplementation(async (options) => {
            if (!options.stream) {
              // Non-streaming path: resolve with a single completion object
              return {
                id: 'test-completion',
                choices: [{
                  message: { role: 'assistant', content: 'Test response' },
                  finish_reason: 'stop',
                  index: 0
                }],
                usage: {
                  prompt_tokens: 10,
                  completion_tokens: 5,
                  total_tokens: 15
                }
              };
            }
            // Streaming path: resolve with an async iterable of chunks
            return {
              [Symbol.asyncIterator]: async function* () {
                yield {
                  choices: [{
                    delta: { content: 'Test response' },
                    index: 0
                  }],
                  usage: null
                };
                yield {
                  choices: [{
                    delta: {},
                    index: 0
                  }],
                  usage: {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15
                  }
                };
              }
            };
          })
        }
      }
    }))
  };
});

describe('LmStudioHandler', () => {
  let handler: LmStudioHandler;
  let mockOptions: ApiHandlerOptions;

  beforeEach(() => {
    mockOptions = {
      apiModelId: 'local-model',
      lmStudioModelId: 'local-model',
      lmStudioBaseUrl: 'http://localhost:1234/v1'
    };
    handler = new LmStudioHandler(mockOptions);
    mockCreate.mockClear();
  });

  describe('constructor', () => {
    it('should initialize with provided options', () => {
      expect(handler).toBeInstanceOf(LmStudioHandler);
      expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
    });

    it('should use default base URL if not provided', () => {
      const handlerWithoutUrl = new LmStudioHandler({
        apiModelId: 'local-model',
        lmStudioModelId: 'local-model'
      });
      expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
    });
  });

  describe('createMessage', () => {
    const systemPrompt = 'You are a helpful assistant.';
    const messages: Anthropic.Messages.MessageParam[] = [
      {
        role: 'user',
        content: 'Hello!'
      }
    ];

    it('should handle streaming responses', async () => {
      const stream = handler.createMessage(systemPrompt, messages);
      const chunks: any[] = [];
      for await (const chunk of stream) {
        chunks.push(chunk);
      }

      expect(chunks.length).toBeGreaterThan(0);
      const textChunks = chunks.filter(chunk => chunk.type === 'text');
      expect(textChunks).toHaveLength(1);
      expect(textChunks[0].text).toBe('Test response');
    });

    it('should handle API errors', async () => {
      mockCreate.mockRejectedValueOnce(new Error('API Error'));
      const stream = handler.createMessage(systemPrompt, messages);

      await expect(async () => {
        for await (const chunk of stream) {
          // Should not reach here
        }
      }).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
    });
  });

  describe('completePrompt', () => {
    it('should complete prompt successfully', async () => {
      const result = await handler.completePrompt('Test prompt');
      expect(result).toBe('Test response');
      expect(mockCreate).toHaveBeenCalledWith({
        model: mockOptions.lmStudioModelId,
        messages: [{ role: 'user', content: 'Test prompt' }],
        temperature: 0,
        stream: false
      });
    });

    it('should handle API errors', async () => {
      mockCreate.mockRejectedValueOnce(new Error('API Error'));
      await expect(handler.completePrompt('Test prompt'))
        .rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
    });

    it('should handle empty response', async () => {
      mockCreate.mockResolvedValueOnce({
        choices: [{ message: { content: '' } }]
      });
      const result = await handler.completePrompt('Test prompt');
      expect(result).toBe('');
    });
  });

  describe('getModel', () => {
    it('should return model info', () => {
      const modelInfo = handler.getModel();
      expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
      expect(modelInfo.info).toBeDefined();
      expect(modelInfo.info.maxTokens).toBe(-1);
      expect(modelInfo.info.contextWindow).toBe(128_000);
    });
  });
});
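
For reference, a minimal sketch of a completePrompt implementation that would satisfy the assertions above, assuming the handler keeps an OpenAI client in a private client field and reads its model id via getModel() (both inferred from the tests; this is not the committed code):

async completePrompt(prompt: string): Promise<string> {
  try {
    // Non-streaming request; the first test asserts exactly these parameters.
    const response = await this.client.chat.completions.create({
      model: this.getModel().id,
      messages: [{ role: 'user', content: prompt }],
      temperature: 0,
      stream: false
    });
    return response.choices[0]?.message.content || '';
  } catch (error) {
    // The error tests expect this wrapper message on any underlying failure.
    throw new Error('Please check the LM Studio developer logs to debug what went wrong');
  }
}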