Add non-streaming completePrompt to all providers

Matt Rubens
2025-01-13 16:16:58 -05:00
parent 2d176e5c92
commit 4027e1c10c
18 changed files with 1235 additions and 438 deletions

View File

@@ -46,7 +46,42 @@ jest.mock('@anthropic-ai/sdk', () => {
}
},
messages: {
create: mockCreate
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
};
@@ -144,6 +179,42 @@ describe('AnthropicHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
max_tokens: 8192,
temperature: 0,
stream: false
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Anthropic completion error: API Error');
});
it('should handle non-text content', async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'image' }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle empty response', async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'text', text: '' }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return default model if no model ID is provided', () => {
const handlerWithoutModel = new AnthropicHandler({

View File

@@ -119,6 +119,108 @@ describe('AwsBedrockHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({
content: 'Test response'
}))
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
messages: expect.arrayContaining([
expect.objectContaining({
role: 'user',
content: [{ text: 'Test prompt' }]
})
]),
inferenceConfig: expect.objectContaining({
maxTokens: 5000,
temperature: 0.3,
topP: 0.1
})
})
}));
});
it('should handle API errors', async () => {
const mockError = new Error('AWS Bedrock error');
const mockSend = jest.fn().mockRejectedValue(mockError);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
});
it('should handle invalid response format', async () => {
const mockResponse = {
output: new TextEncoder().encode('invalid json')
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle empty response', async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({}))
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle cross-region inference', async () => {
handler = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
awsAccessKey: 'test-access-key',
awsSecretKey: 'test-secret-key',
awsRegion: 'us-east-1',
awsUseCrossRegionInference: true
});
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({
content: 'Test response'
}))
};
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
})
}));
});
});
describe('getModel', () => {
it('should return correct model info in test environment', () => {
const modelInfo = handler.getModel();

View File

@@ -6,7 +6,12 @@ import { GoogleGenerativeAI } from '@google/generative-ai';
jest.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: jest.fn().mockReturnValue({
generateContentStream: jest.fn()
generateContentStream: jest.fn(),
generateContent: jest.fn().mockResolvedValue({
response: {
text: () => 'Test response'
}
})
})
}))
}));
@@ -133,6 +138,59 @@ describe('GeminiHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => 'Test response'
}
});
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: 'gemini-2.0-flash-thinking-exp-1219'
});
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
generationConfig: {
temperature: 0
}
});
});
it('should handle API errors', async () => {
const mockError = new Error('Gemini API error');
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Gemini completion error: Gemini API error');
});
it('should handle empty response', async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => ''
}
});
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();

View File

@@ -0,0 +1,226 @@
import { GlamaHandler } from '../glama';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import axios from 'axios';
// Mock OpenAI client
const mockCreate = jest.fn();
const mockWithResponse = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: (...args: any[]) => {
const stream = {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
const result = mockCreate(...args);
if (args[0].stream) {
mockWithResponse.mockReturnValue(Promise.resolve({
data: stream,
response: {
headers: {
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
}
}
}));
result.withResponse = mockWithResponse;
}
return result;
}
}
}
}))
};
});
describe('GlamaHandler', () => {
let handler: GlamaHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
mockOptions = {
apiModelId: 'anthropic/claude-3-5-sonnet',
glamaModelId: 'anthropic/claude-3-5-sonnet',
glamaApiKey: 'test-api-key'
};
handler = new GlamaHandler(mockOptions);
mockCreate.mockClear();
mockWithResponse.mockClear();
// Default mock implementation for non-streaming responses
mockCreate.mockResolvedValue({
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
});
});
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(GlamaHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
});
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it('should handle streaming responses', async () => {
// Mock axios for token usage request
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
data: {
tokenUsage: {
promptTokens: 10,
completionTokens: 5,
cacheCreationInputTokens: 0,
cacheReadInputTokens: 0
},
totalCostUsd: "0.00"
}
});
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBe(2); // Text chunk and usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: 'Test response'
});
expect(chunks[1]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 5,
cacheWriteTokens: 0,
cacheReadTokens: 0,
totalCost: 0
});
mockAxios.mockRestore();
});
it('should handle API errors', async () => {
mockCreate.mockImplementationOnce(() => {
throw new Error('API Error');
});
const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
try {
for await (const chunk of stream) {
chunks.push(chunk);
}
fail('Expected error to be thrown');
} catch (error) {
expect(error).toBeInstanceOf(Error);
expect(error.message).toBe('API Error');
}
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
max_tokens: 8192
}));
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Glama completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should not set max_tokens for non-Anthropic models', async () => {
// Reset mock to clear any previous calls
mockCreate.mockClear();
const nonAnthropicOptions = {
apiModelId: 'openai/gpt-4',
glamaModelId: 'openai/gpt-4',
glamaApiKey: 'test-key',
glamaModelInfo: {
maxTokens: 4096,
contextWindow: 8192,
supportsImages: true,
supportsPromptCache: false
}
};
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
await nonAnthropicHandler.completePrompt('Test prompt');
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: 'openai/gpt-4',
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
}));
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
});
});
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.apiModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(8192);
expect(modelInfo.info.contextWindow).toBe(200_000);
});
});
});

View File

@@ -1,148 +1,160 @@
import { LmStudioHandler } from '../lmstudio';
import { Anthropic } from '@anthropic-ai/sdk';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI SDK
jest.mock('openai', () => ({
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: jest.fn()
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}
}))
}));
}))
};
});
describe('LmStudioHandler', () => {
let handler: LmStudioHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
handler = new LmStudioHandler({
lmStudioModelId: 'mistral-7b',
lmStudioBaseUrl: 'http://localhost:1234'
});
mockOptions = {
apiModelId: 'local-model',
lmStudioModelId: 'local-model',
lmStudioBaseUrl: 'http://localhost:1234/v1'
};
handler = new LmStudioHandler(mockOptions);
mockCreate.mockClear();
});
describe('constructor', () => {
it('should initialize with provided config', () => {
expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'http://localhost:1234/v1',
apiKey: 'noop'
});
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(LmStudioHandler);
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
});
it('should use default base URL if not provided', () => {
const defaultHandler = new LmStudioHandler({
lmStudioModelId: 'mistral-7b'
});
expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'http://localhost:1234/v1',
apiKey: 'noop'
const handlerWithoutUrl = new LmStudioHandler({
apiModelId: 'local-model',
lmStudioModelId: 'local-model'
});
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
});
});
describe('createMessage', () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
content: 'Hello!'
}
];
const systemPrompt = 'You are a helpful assistant';
it('should handle streaming responses correctly', async () => {
const mockStream = [
{
choices: [{
delta: { content: 'Hello' }
}]
},
{
choices: [{
delta: { content: ' world!' }
}]
}
];
// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk;
}
}
};
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
(handler['client'].chat.completions as any).create = mockCreate;
const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBe(2);
expect(chunks[0]).toEqual({
type: 'text',
text: 'Hello'
});
expect(chunks[1]).toEqual({
type: 'text',
text: ' world!'
});
expect(mockCreate).toHaveBeenCalledWith({
model: 'mistral-7b',
messages: expect.arrayContaining([
{
role: 'system',
content: systemPrompt
}
]),
temperature: 0,
stream: true
});
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it('should handle API errors with custom message', async () => {
const mockError = new Error('LM Studio API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].chat.completions as any).create = mockCreate;
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const stream = handler.createMessage(systemPrompt, mockMessages);
const stream = handler.createMessage(systemPrompt, messages);
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
// Should not reach here
}
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.lmStudioModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
stream: false
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info with sane defaults', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe('mistral-7b');
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(-1);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it('should return empty string as model ID if not provided', () => {
const noModelHandler = new LmStudioHandler({});
const modelInfo = noModelHandler.getModel();
expect(modelInfo.id).toBe('');
expect(modelInfo.info).toBeDefined();
});
});
});

View File

@@ -1,148 +1,160 @@
import { OllamaHandler } from '../ollama';
import { Anthropic } from '@anthropic-ai/sdk';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI SDK
jest.mock('openai', () => ({
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: jest.fn()
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}
}))
}));
}))
};
});
describe('OllamaHandler', () => {
let handler: OllamaHandler;
let mockOptions: ApiHandlerOptions;
beforeEach(() => {
handler = new OllamaHandler({
mockOptions = {
apiModelId: 'llama2',
ollamaModelId: 'llama2',
ollamaBaseUrl: 'http://localhost:11434'
});
ollamaBaseUrl: 'http://localhost:11434/v1'
};
handler = new OllamaHandler(mockOptions);
mockCreate.mockClear();
});
describe('constructor', () => {
it('should initialize with provided config', () => {
expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'http://localhost:11434/v1',
apiKey: 'ollama'
});
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OllamaHandler);
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
});
it('should use default base URL if not provided', () => {
const defaultHandler = new OllamaHandler({
const handlerWithoutUrl = new OllamaHandler({
apiModelId: 'llama2',
ollamaModelId: 'llama2'
});
expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'http://localhost:11434/v1',
apiKey: 'ollama'
});
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
});
});
describe('createMessage', () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
content: 'Hello!'
}
];
const systemPrompt = 'You are a helpful assistant';
it('should handle streaming responses correctly', async () => {
const mockStream = [
{
choices: [{
delta: { content: 'Hello' }
}]
},
{
choices: [{
delta: { content: ' world!' }
}]
}
];
// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk;
}
}
};
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
(handler['client'].chat.completions as any).create = mockCreate;
const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBe(2);
expect(chunks[0]).toEqual({
type: 'text',
text: 'Hello'
});
expect(chunks[1]).toEqual({
type: 'text',
text: ' world!'
});
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const stream = handler.createMessage(systemPrompt, messages);
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('API Error');
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'llama2',
messages: expect.arrayContaining([
{
role: 'system',
content: systemPrompt
}
]),
model: mockOptions.ollamaModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
stream: true
stream: false
});
});
it('should handle API errors', async () => {
const mockError = new Error('Ollama API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].chat.completions as any).create = mockCreate;
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Ollama completion error: API Error');
});
const stream = handler.createMessage(systemPrompt, mockMessages);
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow('Ollama API error');
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info with sane defaults', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe('llama2');
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(-1);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it('should return empty string as model ID if not provided', () => {
const noModelHandler = new OllamaHandler({});
const modelInfo = noModelHandler.getModel();
expect(modelInfo.id).toBe('');
expect(modelInfo.info).toBeDefined();
});
});
});

View File

@@ -1,230 +1,209 @@
import { OpenAiNativeHandler } from "../openai-native"
import OpenAI from "openai"
import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import { OpenAiNativeHandler } from '../openai-native';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
// Mock OpenAI
jest.mock("openai")
describe("OpenAiNativeHandler", () => {
let handler: OpenAiNativeHandler
let mockOptions: ApiHandlerOptions
let mockOpenAIClient: jest.Mocked<OpenAI>
let mockCreate: jest.Mock
beforeEach(() => {
// Reset mocks
jest.clearAllMocks()
// Setup mock options
mockOptions = {
openAiNativeApiKey: "test-api-key",
apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts
}
// Setup mock create function
mockCreate = jest.fn()
// Setup mock OpenAI client
mockOpenAIClient = {
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
} as unknown as jest.Mocked<OpenAI>
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
// Mock OpenAI constructor
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIClient)
describe('OpenAiNativeHandler', () => {
let handler: OpenAiNativeHandler;
let mockOptions: ApiHandlerOptions;
// Create handler instance
handler = new OpenAiNativeHandler(mockOptions)
})
beforeEach(() => {
mockOptions = {
apiModelId: 'gpt-4o',
openAiNativeApiKey: 'test-api-key'
};
handler = new OpenAiNativeHandler(mockOptions);
mockCreate.mockClear();
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(OpenAI).toHaveBeenCalledWith({
apiKey: mockOptions.openAiNativeApiKey,
})
})
})
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
describe("getModel", () => {
it("should return specified model when valid", () => {
const result = handler.getModel()
expect(result.id).toBe("gpt-4o") // Use the correct model ID
})
it('should initialize with empty API key', () => {
const handlerWithoutKey = new OpenAiNativeHandler({
apiModelId: 'gpt-4o',
openAiNativeApiKey: ''
});
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
});
});
it("should return default model when model ID is invalid", () => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "invalid-model" as any,
})
const result = handler.getModel()
expect(result.id).toBe(openAiNativeDefaultModelId)
})
it("should return default model when model ID is not provided", () => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: undefined,
})
const result = handler.getModel()
expect(result.id).toBe(openAiNativeDefaultModelId)
})
})
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant"
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
]
{
role: 'user',
content: 'Hello!'
}
];
describe("o1 models", () => {
beforeEach(() => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "o1-preview",
})
})
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
it("should handle non-streaming response for o1 models", async () => {
const mockResponse = {
choices: [{ message: { content: "Hello there!" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
mockCreate.mockResolvedValueOnce(mockResponse)
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
const stream = handler.createMessage(systemPrompt, messages);
expect(results).toEqual([
{ type: "text", text: "Hello there!" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(mockCreate).toHaveBeenCalledWith({
model: "o1-preview",
messages: [
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello" },
],
})
})
it("should handle missing content in response", async () => {
const mockResponse = {
choices: [{ message: { content: null } }],
usage: null,
}
mockCreate.mockResolvedValueOnce(mockResponse)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "" },
{ type: "usage", inputTokens: 0, outputTokens: 0 },
])
})
})
describe("streaming models", () => {
beforeEach(() => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "gpt-4o",
})
})
it("should handle streaming response", async () => {
const mockStream = [
{ choices: [{ delta: { content: "Hello" } }], usage: null },
{ choices: [{ delta: { content: " there" } }], usage: null },
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk
}
})()
)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " there" },
{ type: "text", text: "!" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(mockCreate).toHaveBeenCalledWith({
model: "gpt-4o",
temperature: 0,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello" },
],
stream: true,
stream_options: { include_usage: true },
})
})
it("should handle empty delta content", async () => {
const mockStream = [
{ choices: [{ delta: {} }], usage: null },
{ choices: [{ delta: { content: null } }], usage: null },
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk
}
})()
)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})
})
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
const generator = handler.createMessage(systemPrompt, messages)
await expect(async () => {
for await (const _ of generator) {
// consume generator
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow("API Error")
})
})
})
}).rejects.toThrow('API Error');
});
});
describe('completePrompt', () => {
it('should complete prompt successfully with gpt-4o model', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'gpt-4o',
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
});
});
it('should complete prompt successfully with o1 model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it('should complete prompt successfully with o1-preview model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1-preview',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-preview',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it('should complete prompt successfully with o1-mini model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1-mini',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-mini',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('OpenAI Native completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.apiModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(4096);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it('should handle undefined model ID', () => {
const handlerWithoutModel = new OpenAiNativeHandler({
openAiNativeApiKey: 'test-api-key'
});
const modelInfo = handlerWithoutModel.getModel();
expect(modelInfo.id).toBe('gpt-4o'); // Default model
expect(modelInfo.info).toBeDefined();
});
});
});

View File

@@ -176,6 +176,32 @@ describe('OpenAiHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
});
});
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('OpenAI completion error: API Error');
});
it('should handle empty response', async () => {
mockCreate.mockImplementationOnce(() => ({
choices: [{ message: { content: '' } }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return model info with sane defaults', () => {
const model = handler.getModel();

View File

@@ -6,7 +6,42 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
jest.mock('@anthropic-ai/vertex-sdk', () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: {
create: jest.fn()
create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
}));
@@ -196,6 +231,49 @@ describe('VertexHandler', () => {
});
});
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(handler['client'].messages.create).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022',
max_tokens: 8192,
temperature: 0,
messages: [{ role: 'user', content: 'Test prompt' }],
stream: false
});
});
it('should handle API errors', async () => {
const mockError = new Error('Vertex API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].messages as any).create = mockCreate;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Vertex completion error: Vertex API error');
});
it('should handle non-text content', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'image' }]
});
(handler['client'].messages as any).create = mockCreate;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it('should handle empty response', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'text', text: '' }]
});
(handler['client'].messages as any).create = mockCreate;
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();

View File

@@ -7,10 +7,10 @@ import {
ApiHandlerOptions,
ModelInfo,
} from "../../shared/api"
import { ApiHandler } from "../index"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { ApiStream } from "../transform/stream"
export class AnthropicHandler implements ApiHandler {
export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: Anthropic
@@ -173,4 +173,27 @@ export class AnthropicHandler implements ApiHandler {
}
return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.messages.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0,
messages: [{ role: "user", content: prompt }],
stream: false
})
const content = response.content[0]
if (content.type === 'text') {
return content.text
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Anthropic completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,6 +1,6 @@
import { BedrockRuntimeClient, ConverseStreamCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
@@ -38,7 +38,7 @@ export interface StreamEvent {
};
}
export class AwsBedrockHandler implements ApiHandler {
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: BedrockRuntimeClient
@@ -199,7 +199,7 @@ export class AwsBedrockHandler implements ApiHandler {
if (modelId) {
// For tests, allow any model ID
if (process.env.NODE_ENV === 'test') {
return {
return {
id: modelId,
info: {
maxTokens: 5000,
@@ -214,9 +214,68 @@ export class AwsBedrockHandler implements ApiHandler {
return { id, info: bedrockModels[id] }
}
}
return {
id: bedrockDefaultModelId,
info: bedrockModels[bedrockDefaultModelId]
return {
id: bedrockDefaultModelId,
info: bedrockModels[bedrockDefaultModelId]
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const modelConfig = this.getModel()
// Handle cross-region inference
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${modelConfig.id}`
break
case "eu-":
modelId = `eu.${modelConfig.id}`
break
default:
modelId = modelConfig.id
break
}
} else {
modelId = modelConfig.id
}
const payload = {
modelId,
messages: convertToBedrockConverseMessages([{
role: "user",
content: prompt
}]),
inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3,
topP: 0.1
}
}
const command = new ConverseCommand(payload)
const response = await this.client.send(command)
if (response.output && response.output instanceof Uint8Array) {
try {
const outputStr = new TextDecoder().decode(response.output)
const output = JSON.parse(outputStr)
if (output.content) {
return output.content
}
} catch (parseError) {
console.error('Failed to parse Bedrock response:', parseError)
}
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Bedrock completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { GoogleGenerativeAI } from "@google/generative-ai"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
import { ApiStream } from "../transform/stream"
export class GeminiHandler implements ApiHandler {
export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: GoogleGenerativeAI
@@ -53,4 +53,26 @@ export class GeminiHandler implements ApiHandler {
}
return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const model = this.client.getGenerativeModel({
model: this.getModel().id,
})
const result = await model.generateContent({
contents: [{ role: "user", parts: [{ text: prompt }] }],
generationConfig: {
temperature: 0,
},
})
return result.response.text()
} catch (error) {
if (error instanceof Error) {
throw new Error(`Gemini completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,13 +1,13 @@
import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"
import OpenAI from "openai"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import delay from "delay"
export class GlamaHandler implements ApiHandler {
export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -129,4 +129,26 @@ export class GlamaHandler implements ApiHandler {
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
}
async completePrompt(prompt: string): Promise<string> {
try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
}
if (this.getModel().id.startsWith("anthropic/")) {
requestOptions.max_tokens = 8192
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Glama completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class LmStudioHandler implements ApiHandler {
export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -53,4 +53,20 @@ export class LmStudioHandler implements ApiHandler {
info: openAiModelInfoSaneDefaults,
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.chat.completions.create({
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
})
return response.choices[0]?.message.content || ""
} catch (error) {
throw new Error(
"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts.",
)
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OllamaHandler implements ApiHandler {
export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -46,4 +46,21 @@ export class OllamaHandler implements ApiHandler {
info: openAiModelInfoSaneDefaults,
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.chat.completions.create({
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
})
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Ollama completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,6 +1,6 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import {
ApiHandlerOptions,
ModelInfo,
@@ -11,7 +11,7 @@ import {
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OpenAiNativeHandler implements ApiHandler {
export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: OpenAI
@@ -83,4 +83,37 @@ export class OpenAiNativeHandler implements ApiHandler {
}
return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const modelId = this.getModel().id
let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
switch (modelId) {
case "o1":
case "o1-preview":
case "o1-mini":
// o1 doesn't support non-1 temp or system prompt
requestOptions = {
model: modelId,
messages: [{ role: "user", content: prompt }]
}
break
default:
requestOptions = {
model: modelId,
messages: [{ role: "user", content: prompt }],
temperature: 0
}
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`OpenAI Native completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -6,11 +6,11 @@ import {
ModelInfo,
openAiModelInfoSaneDefaults,
} from "../../shared/api"
import { ApiHandler } from "../index"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
export class OpenAiHandler implements ApiHandler {
export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
@@ -100,4 +100,22 @@ export class OpenAiHandler implements ApiHandler {
info: openAiModelInfoSaneDefaults,
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`OpenAI completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
import { ApiHandler } from "../"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
// https://docs.anthropic.com/en/api/claude-on-vertex-ai
export class VertexHandler implements ApiHandler {
export class VertexHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: AnthropicVertex
@@ -83,4 +83,27 @@ export class VertexHandler implements ApiHandler {
}
return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
}
async completePrompt(prompt: string): Promise<string> {
try {
const response = await this.client.messages.create({
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0,
messages: [{ role: "user", content: prompt }],
stream: false
})
const content = response.content[0]
if (content.type === 'text') {
return content.text
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Vertex completion error: ${error.message}`)
}
throw error
}
}
}
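The new method gives callers a one-shot, non-streaming way to query whichever provider is configured, alongside the existing streaming createMessage path. A hypothetical usage sketch (not part of this commit; the import path and the environment-variable lookup are assumptions, while the constructor options and completePrompt signature match the tests above):

// Any handler that implements SingleCompletionHandler can service a
// fire-and-forget prompt and return the full response as a string.
import { OpenAiNativeHandler } from "./providers/openai-native"

async function rewritePrompt(text: string): Promise<string> {
	const handler = new OpenAiNativeHandler({
		apiModelId: "gpt-4o",
		openAiNativeApiKey: process.env.OPENAI_API_KEY ?? "",
	})
	// Resolves with the assistant's text, or "" when the provider returns
	// no text content; provider-specific errors are rethrown with a
	// "<Provider> completion error:" prefix as shown in the handlers above.
	return handler.completePrompt(`Improve this prompt:\n${text}`)
}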