mirror of https://github.com/pacnpal/Roo-Code.git, synced 2026-02-05 12:05:16 -05:00
Prettier backfill
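The change below is mechanical: single quotes become double quotes, trailing semicolons are dropped, and multiline literals gain trailing commas. As a rough sketch, a Prettier config along these lines would produce the style seen in this diff; the option names are real Prettier options, but the values marked "assumed" (and the config file itself) are illustrative, not taken from this commit.

```typescript
// prettier.config.ts — a sketch, not the repository's actual config file.
import type { Config } from "prettier"

const config: Config = {
	semi: false, // drop trailing semicolons
	singleQuote: false, // keep double quotes
	trailingComma: "all", // add trailing commas in multiline literals
	printWidth: 120, // assumed: long expect(...) calls stay on one line
	useTabs: true, // assumed indentation style
}

export default config
```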
@@ -1,239 +1,238 @@
-import { AnthropicHandler } from '../anthropic';
-import { ApiHandlerOptions } from '../../../shared/api';
-import { ApiStream } from '../../transform/stream';
-import { Anthropic } from '@anthropic-ai/sdk';
+import { AnthropicHandler } from "../anthropic"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { ApiStream } from "../../transform/stream"
+import { Anthropic } from "@anthropic-ai/sdk"

 // Mock Anthropic client
-const mockBetaCreate = jest.fn();
-const mockCreate = jest.fn();
-jest.mock('@anthropic-ai/sdk', () => {
-	return {
-		Anthropic: jest.fn().mockImplementation(() => ({
-			beta: {
-				promptCaching: {
-					messages: {
-						create: mockBetaCreate.mockImplementation(async () => ({
-							async *[Symbol.asyncIterator]() {
-								yield {
-									type: 'message_start',
-									message: {
-										usage: {
-											input_tokens: 100,
-											output_tokens: 50,
-											cache_creation_input_tokens: 20,
-											cache_read_input_tokens: 10
-										}
-									}
-								};
-								yield {
-									type: 'content_block_start',
-									index: 0,
-									content_block: {
-										type: 'text',
-										text: 'Hello'
-									}
-								};
-								yield {
-									type: 'content_block_delta',
-									delta: {
-										type: 'text_delta',
-										text: ' world'
-									}
-								};
-							}
-						}))
-					}
-				}
-			},
-			messages: {
-				create: mockCreate.mockImplementation(async (options) => {
-					if (!options.stream) {
-						return {
-							id: 'test-completion',
-							content: [
-								{ type: 'text', text: 'Test response' }
-							],
-							role: 'assistant',
-							model: options.model,
-							usage: {
-								input_tokens: 10,
-								output_tokens: 5
-							}
-						}
-					}
-					return {
-						async *[Symbol.asyncIterator]() {
-							yield {
-								type: 'message_start',
-								message: {
-									usage: {
-										input_tokens: 10,
-										output_tokens: 5
-									}
-								}
-							}
-							yield {
-								type: 'content_block_start',
-								content_block: {
-									type: 'text',
-									text: 'Test response'
-								}
-							}
-						}
-					}
-				})
-			}
-		}))
-	};
-});
+const mockBetaCreate = jest.fn()
+const mockCreate = jest.fn()
+jest.mock("@anthropic-ai/sdk", () => {
+	return {
+		Anthropic: jest.fn().mockImplementation(() => ({
+			beta: {
+				promptCaching: {
+					messages: {
+						create: mockBetaCreate.mockImplementation(async () => ({
+							async *[Symbol.asyncIterator]() {
+								yield {
+									type: "message_start",
+									message: {
+										usage: {
+											input_tokens: 100,
+											output_tokens: 50,
+											cache_creation_input_tokens: 20,
+											cache_read_input_tokens: 10,
+										},
+									},
+								}
+								yield {
+									type: "content_block_start",
+									index: 0,
+									content_block: {
+										type: "text",
+										text: "Hello",
+									},
+								}
+								yield {
+									type: "content_block_delta",
+									delta: {
+										type: "text_delta",
+										text: " world",
+									},
+								}
+							},
+						})),
+					},
+				},
+			},
+			messages: {
+				create: mockCreate.mockImplementation(async (options) => {
+					if (!options.stream) {
+						return {
+							id: "test-completion",
+							content: [{ type: "text", text: "Test response" }],
+							role: "assistant",
+							model: options.model,
+							usage: {
+								input_tokens: 10,
+								output_tokens: 5,
+							},
+						}
+					}
+					return {
+						async *[Symbol.asyncIterator]() {
+							yield {
+								type: "message_start",
+								message: {
+									usage: {
+										input_tokens: 10,
+										output_tokens: 5,
+									},
+								},
+							}
+							yield {
+								type: "content_block_start",
+								content_block: {
+									type: "text",
+									text: "Test response",
+								},
+							}
+						},
+					}
+				}),
+			},
+		})),
+	}
+})

-describe('AnthropicHandler', () => {
-	let handler: AnthropicHandler;
-	let mockOptions: ApiHandlerOptions;
+describe("AnthropicHandler", () => {
+	let handler: AnthropicHandler
+	let mockOptions: ApiHandlerOptions

-	beforeEach(() => {
-		mockOptions = {
-			apiKey: 'test-api-key',
-			apiModelId: 'claude-3-5-sonnet-20241022'
-		};
-		handler = new AnthropicHandler(mockOptions);
-		mockBetaCreate.mockClear();
-		mockCreate.mockClear();
-	});
+	beforeEach(() => {
+		mockOptions = {
+			apiKey: "test-api-key",
+			apiModelId: "claude-3-5-sonnet-20241022",
+		}
+		handler = new AnthropicHandler(mockOptions)
+		mockBetaCreate.mockClear()
+		mockCreate.mockClear()
+	})

-	describe('constructor', () => {
-		it('should initialize with provided options', () => {
-			expect(handler).toBeInstanceOf(AnthropicHandler);
-			expect(handler.getModel().id).toBe(mockOptions.apiModelId);
-		});
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(AnthropicHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})

-		it('should initialize with undefined API key', () => {
-			// The SDK will handle API key validation, so we just verify it initializes
-			const handlerWithoutKey = new AnthropicHandler({
-				...mockOptions,
-				apiKey: undefined
-			});
-			expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler);
-		});
+		it("should initialize with undefined API key", () => {
+			// The SDK will handle API key validation, so we just verify it initializes
+			const handlerWithoutKey = new AnthropicHandler({
+				...mockOptions,
+				apiKey: undefined,
+			})
+			expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler)
+		})

-		it('should use custom base URL if provided', () => {
-			const customBaseUrl = 'https://custom.anthropic.com';
-			const handlerWithCustomUrl = new AnthropicHandler({
-				...mockOptions,
-				anthropicBaseUrl: customBaseUrl
-			});
-			expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler);
-		});
-	});
+		it("should use custom base URL if provided", () => {
+			const customBaseUrl = "https://custom.anthropic.com"
+			const handlerWithCustomUrl = new AnthropicHandler({
+				...mockOptions,
+				anthropicBaseUrl: customBaseUrl,
+			})
+			expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler)
+		})
+	})

-	describe('createMessage', () => {
-		const systemPrompt = 'You are a helpful assistant.';
-		const messages: Anthropic.Messages.MessageParam[] = [
-			{
-				role: 'user',
-				content: [{
-					type: 'text' as const,
-					text: 'Hello!'
-				}]
-			}
-		];
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]

-		it('should handle prompt caching for supported models', async () => {
-			const stream = handler.createMessage(systemPrompt, [
-				{
-					role: 'user',
-					content: [{ type: 'text' as const, text: 'First message' }]
-				},
-				{
-					role: 'assistant',
-					content: [{ type: 'text' as const, text: 'Response' }]
-				},
-				{
-					role: 'user',
-					content: [{ type: 'text' as const, text: 'Second message' }]
-				}
-			]);
+		it("should handle prompt caching for supported models", async () => {
+			const stream = handler.createMessage(systemPrompt, [
+				{
+					role: "user",
+					content: [{ type: "text" as const, text: "First message" }],
+				},
+				{
+					role: "assistant",
+					content: [{ type: "text" as const, text: "Response" }],
+				},
+				{
+					role: "user",
+					content: [{ type: "text" as const, text: "Second message" }],
+				},
+			])

-			const chunks: any[] = [];
-			for await (const chunk of stream) {
-				chunks.push(chunk);
-			}
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}

-			// Verify usage information
-			const usageChunk = chunks.find(chunk => chunk.type === 'usage');
-			expect(usageChunk).toBeDefined();
-			expect(usageChunk?.inputTokens).toBe(100);
-			expect(usageChunk?.outputTokens).toBe(50);
-			expect(usageChunk?.cacheWriteTokens).toBe(20);
-			expect(usageChunk?.cacheReadTokens).toBe(10);
+			// Verify usage information
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
+			expect(usageChunk).toBeDefined()
+			expect(usageChunk?.inputTokens).toBe(100)
+			expect(usageChunk?.outputTokens).toBe(50)
+			expect(usageChunk?.cacheWriteTokens).toBe(20)
+			expect(usageChunk?.cacheReadTokens).toBe(10)

-			// Verify text content
-			const textChunks = chunks.filter(chunk => chunk.type === 'text');
-			expect(textChunks).toHaveLength(2);
-			expect(textChunks[0].text).toBe('Hello');
-			expect(textChunks[1].text).toBe(' world');
+			// Verify text content
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2)
+			expect(textChunks[0].text).toBe("Hello")
+			expect(textChunks[1].text).toBe(" world")

-			// Verify beta API was used
-			expect(mockBetaCreate).toHaveBeenCalled();
-			expect(mockCreate).not.toHaveBeenCalled();
-		});
-	});
+			// Verify beta API was used
+			expect(mockBetaCreate).toHaveBeenCalled()
+			expect(mockCreate).not.toHaveBeenCalled()
+		})
+	})

-	describe('completePrompt', () => {
-		it('should complete prompt successfully', async () => {
-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('Test response');
-			expect(mockCreate).toHaveBeenCalledWith({
-				model: mockOptions.apiModelId,
-				messages: [{ role: 'user', content: 'Test prompt' }],
-				max_tokens: 8192,
-				temperature: 0,
-				stream: false
-			});
-		});
+	describe("completePrompt", () => {
+		it("should complete prompt successfully", async () => {
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: mockOptions.apiModelId,
+				messages: [{ role: "user", content: "Test prompt" }],
+				max_tokens: 8192,
+				temperature: 0,
+				stream: false,
+			})
+		})

-		it('should handle API errors', async () => {
-			mockCreate.mockRejectedValueOnce(new Error('API Error'));
-			await expect(handler.completePrompt('Test prompt'))
-				.rejects.toThrow('Anthropic completion error: API Error');
-		});
+		it("should handle API errors", async () => {
+			mockCreate.mockRejectedValueOnce(new Error("API Error"))
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error")
+		})

-		it('should handle non-text content', async () => {
-			mockCreate.mockImplementationOnce(async () => ({
-				content: [{ type: 'image' }]
-			}));
-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('');
-		});
+		it("should handle non-text content", async () => {
+			mockCreate.mockImplementationOnce(async () => ({
+				content: [{ type: "image" }],
+			}))
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})

-		it('should handle empty response', async () => {
-			mockCreate.mockImplementationOnce(async () => ({
-				content: [{ type: 'text', text: '' }]
-			}));
-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('');
-		});
-	});
+		it("should handle empty response", async () => {
+			mockCreate.mockImplementationOnce(async () => ({
+				content: [{ type: "text", text: "" }],
+			}))
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})
+	})

-	describe('getModel', () => {
-		it('should return default model if no model ID is provided', () => {
-			const handlerWithoutModel = new AnthropicHandler({
-				...mockOptions,
-				apiModelId: undefined
-			});
-			const model = handlerWithoutModel.getModel();
-			expect(model.id).toBeDefined();
-			expect(model.info).toBeDefined();
-		});
+	describe("getModel", () => {
+		it("should return default model if no model ID is provided", () => {
+			const handlerWithoutModel = new AnthropicHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBeDefined()
+			expect(model.info).toBeDefined()
+		})

-		it('should return specified model if valid model ID is provided', () => {
-			const model = handler.getModel();
-			expect(model.id).toBe(mockOptions.apiModelId);
-			expect(model.info).toBeDefined();
-			expect(model.info.maxTokens).toBe(8192);
-			expect(model.info.contextWindow).toBe(200_000);
-			expect(model.info.supportsImages).toBe(true);
-			expect(model.info.supportsPromptCache).toBe(true);
-		});
-	});
-});
+		it("should return specified model if valid model ID is provided", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe(mockOptions.apiModelId)
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(200_000)
+			expect(model.info.supportsImages).toBe(true)
+			expect(model.info.supportsPromptCache).toBe(true)
+		})
+	})
+})

@@ -1,246 +1,259 @@
-import { AwsBedrockHandler } from '../bedrock';
-import { MessageContent } from '../../../shared/api';
-import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
-import { Anthropic } from '@anthropic-ai/sdk';
+import { AwsBedrockHandler } from "../bedrock"
+import { MessageContent } from "../../../shared/api"
+import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
+import { Anthropic } from "@anthropic-ai/sdk"

-describe('AwsBedrockHandler', () => {
-	let handler: AwsBedrockHandler;
+describe("AwsBedrockHandler", () => {
+	let handler: AwsBedrockHandler

-	beforeEach(() => {
-		handler = new AwsBedrockHandler({
-			apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
-			awsAccessKey: 'test-access-key',
-			awsSecretKey: 'test-secret-key',
-			awsRegion: 'us-east-1'
-		});
-	});
+	beforeEach(() => {
+		handler = new AwsBedrockHandler({
+			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+			awsAccessKey: "test-access-key",
+			awsSecretKey: "test-secret-key",
+			awsRegion: "us-east-1",
+		})
+	})

-	describe('constructor', () => {
-		it('should initialize with provided config', () => {
-			expect(handler['options'].awsAccessKey).toBe('test-access-key');
-			expect(handler['options'].awsSecretKey).toBe('test-secret-key');
-			expect(handler['options'].awsRegion).toBe('us-east-1');
-			expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
-		});
+	describe("constructor", () => {
+		it("should initialize with provided config", () => {
+			expect(handler["options"].awsAccessKey).toBe("test-access-key")
+			expect(handler["options"].awsSecretKey).toBe("test-secret-key")
+			expect(handler["options"].awsRegion).toBe("us-east-1")
+			expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
+		})

-		it('should initialize with missing AWS credentials', () => {
-			const handlerWithoutCreds = new AwsBedrockHandler({
-				apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
-				awsRegion: 'us-east-1'
-			});
-			expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler);
-		});
-	});
+		it("should initialize with missing AWS credentials", () => {
+			const handlerWithoutCreds = new AwsBedrockHandler({
+				apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+				awsRegion: "us-east-1",
+			})
+			expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
+		})
+	})

-	describe('createMessage', () => {
-		const mockMessages: Anthropic.Messages.MessageParam[] = [
-			{
-				role: 'user',
-				content: 'Hello'
-			},
-			{
-				role: 'assistant',
-				content: 'Hi there!'
-			}
-		];
+	describe("createMessage", () => {
+		const mockMessages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: "Hello",
+			},
+			{
+				role: "assistant",
+				content: "Hi there!",
+			},
+		]

-		const systemPrompt = 'You are a helpful assistant';
+		const systemPrompt = "You are a helpful assistant"

-		it('should handle text messages correctly', async () => {
-			const mockResponse = {
-				messages: [{
-					role: 'assistant',
-					content: [{ type: 'text', text: 'Hello! How can I help you?' }]
-				}],
-				usage: {
-					input_tokens: 10,
-					output_tokens: 5
-				}
-			};
+		it("should handle text messages correctly", async () => {
+			const mockResponse = {
+				messages: [
+					{
+						role: "assistant",
+						content: [{ type: "text", text: "Hello! How can I help you?" }],
+					},
+				],
+				usage: {
+					input_tokens: 10,
+					output_tokens: 5,
+				},
+			}

-			// Mock AWS SDK invoke
-			const mockStream = {
-				[Symbol.asyncIterator]: async function* () {
-					yield {
-						metadata: {
-							usage: {
-								inputTokens: 10,
-								outputTokens: 5
-							}
-						}
-					};
-				}
-			};
+			// Mock AWS SDK invoke
+			const mockStream = {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						metadata: {
+							usage: {
+								inputTokens: 10,
+								outputTokens: 5,
+							},
+						},
+					}
+				},
+			}

-			const mockInvoke = jest.fn().mockResolvedValue({
-				stream: mockStream
-			});
+			const mockInvoke = jest.fn().mockResolvedValue({
+				stream: mockStream,
+			})

-			handler['client'] = {
-				send: mockInvoke
-			} as unknown as BedrockRuntimeClient;
+			handler["client"] = {
+				send: mockInvoke,
+			} as unknown as BedrockRuntimeClient

-			const stream = handler.createMessage(systemPrompt, mockMessages);
-			const chunks = [];
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks = []

-			for await (const chunk of stream) {
-				chunks.push(chunk);
-			}
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}

-			expect(chunks.length).toBeGreaterThan(0);
-			expect(chunks[0]).toEqual({
-				type: 'usage',
-				inputTokens: 10,
-				outputTokens: 5
-			});
+			expect(chunks.length).toBeGreaterThan(0)
+			expect(chunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})

-			expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({
-				input: expect.objectContaining({
-					modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0'
-				})
-			}));
-		});
+			expect(mockInvoke).toHaveBeenCalledWith(
+				expect.objectContaining({
+					input: expect.objectContaining({
+						modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+					}),
+				}),
+			)
+		})

-		it('should handle API errors', async () => {
-			// Mock AWS SDK invoke with error
-			const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error'));
+		it("should handle API errors", async () => {
+			// Mock AWS SDK invoke with error
+			const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))

-			handler['client'] = {
-				send: mockInvoke
-			} as unknown as BedrockRuntimeClient;
+			handler["client"] = {
+				send: mockInvoke,
+			} as unknown as BedrockRuntimeClient

-			const stream = handler.createMessage(systemPrompt, mockMessages);
+			const stream = handler.createMessage(systemPrompt, mockMessages)

-			await expect(async () => {
-				for await (const chunk of stream) {
-					// Should throw before yielding any chunks
-				}
-			}).rejects.toThrow('AWS Bedrock error');
-		});
-	});
+			await expect(async () => {
+				for await (const chunk of stream) {
+					// Should throw before yielding any chunks
+				}
+			}).rejects.toThrow("AWS Bedrock error")
+		})
+	})

-	describe('completePrompt', () => {
-		it('should complete prompt successfully', async () => {
-			const mockResponse = {
-				output: new TextEncoder().encode(JSON.stringify({
-					content: 'Test response'
-				}))
-			};
+	describe("completePrompt", () => {
+		it("should complete prompt successfully", async () => {
+			const mockResponse = {
+				output: new TextEncoder().encode(
+					JSON.stringify({
+						content: "Test response",
+					}),
+				),
+			}

-			const mockSend = jest.fn().mockResolvedValue(mockResponse);
-			handler['client'] = {
-				send: mockSend
-			} as unknown as BedrockRuntimeClient;
+			const mockSend = jest.fn().mockResolvedValue(mockResponse)
+			handler["client"] = {
+				send: mockSend,
+			} as unknown as BedrockRuntimeClient

-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('Test response');
-			expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
-				input: expect.objectContaining({
-					modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
-					messages: expect.arrayContaining([
-						expect.objectContaining({
-							role: 'user',
-							content: [{ text: 'Test prompt' }]
-						})
-					]),
-					inferenceConfig: expect.objectContaining({
-						maxTokens: 5000,
-						temperature: 0.3,
-						topP: 0.1
-					})
-				})
-			}));
-		});
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockSend).toHaveBeenCalledWith(
+				expect.objectContaining({
+					input: expect.objectContaining({
+						modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+						messages: expect.arrayContaining([
+							expect.objectContaining({
+								role: "user",
+								content: [{ text: "Test prompt" }],
+							}),
+						]),
+						inferenceConfig: expect.objectContaining({
+							maxTokens: 5000,
+							temperature: 0.3,
+							topP: 0.1,
+						}),
+					}),
+				}),
+			)
+		})

-		it('should handle API errors', async () => {
-			const mockError = new Error('AWS Bedrock error');
-			const mockSend = jest.fn().mockRejectedValue(mockError);
-			handler['client'] = {
-				send: mockSend
-			} as unknown as BedrockRuntimeClient;
+		it("should handle API errors", async () => {
+			const mockError = new Error("AWS Bedrock error")
+			const mockSend = jest.fn().mockRejectedValue(mockError)
+			handler["client"] = {
+				send: mockSend,
+			} as unknown as BedrockRuntimeClient

-			await expect(handler.completePrompt('Test prompt'))
-				.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
-		});
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
+				"Bedrock completion error: AWS Bedrock error",
+			)
+		})

-		it('should handle invalid response format', async () => {
-			const mockResponse = {
-				output: new TextEncoder().encode('invalid json')
-			};
+		it("should handle invalid response format", async () => {
+			const mockResponse = {
+				output: new TextEncoder().encode("invalid json"),
+			}

-			const mockSend = jest.fn().mockResolvedValue(mockResponse);
-			handler['client'] = {
-				send: mockSend
-			} as unknown as BedrockRuntimeClient;
+			const mockSend = jest.fn().mockResolvedValue(mockResponse)
+			handler["client"] = {
+				send: mockSend,
+			} as unknown as BedrockRuntimeClient

-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('');
-		});
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})

-		it('should handle empty response', async () => {
-			const mockResponse = {
-				output: new TextEncoder().encode(JSON.stringify({}))
-			};
+		it("should handle empty response", async () => {
+			const mockResponse = {
+				output: new TextEncoder().encode(JSON.stringify({})),
+			}

-			const mockSend = jest.fn().mockResolvedValue(mockResponse);
-			handler['client'] = {
-				send: mockSend
-			} as unknown as BedrockRuntimeClient;
+			const mockSend = jest.fn().mockResolvedValue(mockResponse)
+			handler["client"] = {
+				send: mockSend,
+			} as unknown as BedrockRuntimeClient

-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('');
-		});
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})

-		it('should handle cross-region inference', async () => {
-			handler = new AwsBedrockHandler({
-				apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
-				awsAccessKey: 'test-access-key',
-				awsSecretKey: 'test-secret-key',
-				awsRegion: 'us-east-1',
-				awsUseCrossRegionInference: true
-			});
+		it("should handle cross-region inference", async () => {
+			handler = new AwsBedrockHandler({
+				apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-east-1",
+				awsUseCrossRegionInference: true,
+			})

-			const mockResponse = {
-				output: new TextEncoder().encode(JSON.stringify({
-					content: 'Test response'
-				}))
-			};
+			const mockResponse = {
+				output: new TextEncoder().encode(
+					JSON.stringify({
+						content: "Test response",
+					}),
+				),
+			}

-			const mockSend = jest.fn().mockResolvedValue(mockResponse);
-			handler['client'] = {
-				send: mockSend
-			} as unknown as BedrockRuntimeClient;
+			const mockSend = jest.fn().mockResolvedValue(mockResponse)
+			handler["client"] = {
+				send: mockSend,
+			} as unknown as BedrockRuntimeClient

-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('Test response');
-			expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
-				input: expect.objectContaining({
-					modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
-				})
-			}));
-		});
-	});
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockSend).toHaveBeenCalledWith(
+				expect.objectContaining({
+					input: expect.objectContaining({
+						modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+					}),
+				}),
+			)
+		})
+	})

-	describe('getModel', () => {
-		it('should return correct model info in test environment', () => {
-			const modelInfo = handler.getModel();
-			expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
-			expect(modelInfo.info).toBeDefined();
-			expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value
-			expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value
-		});
+	describe("getModel", () => {
+		it("should return correct model info in test environment", () => {
+			const modelInfo = handler.getModel()
+			expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
+			expect(modelInfo.info).toBeDefined()
+			expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
+			expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
+		})

-		it('should return test model info for invalid model in test environment', () => {
-			const invalidHandler = new AwsBedrockHandler({
-				apiModelId: 'invalid-model',
-				awsAccessKey: 'test-access-key',
-				awsSecretKey: 'test-secret-key',
-				awsRegion: 'us-east-1'
-			});
-			const modelInfo = invalidHandler.getModel();
-			expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed
-			expect(modelInfo.info.maxTokens).toBe(5000);
-			expect(modelInfo.info.contextWindow).toBe(128_000);
-		});
-	});
-});
+		it("should return test model info for invalid model in test environment", () => {
+			const invalidHandler = new AwsBedrockHandler({
+				apiModelId: "invalid-model",
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-east-1",
+			})
+			const modelInfo = invalidHandler.getModel()
+			expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
+			expect(modelInfo.info.maxTokens).toBe(5000)
+			expect(modelInfo.info.contextWindow).toBe(128_000)
+		})
+	})
+})

@@ -1,203 +1,217 @@
-import { DeepSeekHandler } from '../deepseek';
-import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api';
-import OpenAI from 'openai';
-import { Anthropic } from '@anthropic-ai/sdk';
+import { DeepSeekHandler } from "../deepseek"
+import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"

 // Mock OpenAI client
-const mockCreate = jest.fn();
-jest.mock('openai', () => {
-	return {
-		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
-			chat: {
-				completions: {
-					create: mockCreate.mockImplementation(async (options) => {
-						if (!options.stream) {
-							return {
-								id: 'test-completion',
-								choices: [{
-									message: { role: 'assistant', content: 'Test response', refusal: null },
-									finish_reason: 'stop',
-									index: 0
-								}],
-								usage: {
-									prompt_tokens: 10,
-									completion_tokens: 5,
-									total_tokens: 15
-								}
-							};
-						}
-
-						// Return async iterator for streaming
-						return {
-							[Symbol.asyncIterator]: async function* () {
-								yield {
-									choices: [{
-										delta: { content: 'Test response' },
-										index: 0
-									}],
-									usage: null
-								};
-								yield {
-									choices: [{
-										delta: {},
-										index: 0
-									}],
-									usage: {
-										prompt_tokens: 10,
-										completion_tokens: 5,
-										total_tokens: 15
-									}
-								};
-							}
-						};
-					})
-				}
-			}
-		}))
-	};
-});
+const mockCreate = jest.fn()
+jest.mock("openai", () => {
+	return {
+		__esModule: true,
+		default: jest.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate.mockImplementation(async (options) => {
+						if (!options.stream) {
+							return {
+								id: "test-completion",
+								choices: [
+									{
+										message: { role: "assistant", content: "Test response", refusal: null },
+										finish_reason: "stop",
+										index: 0,
+									},
+								],
+								usage: {
+									prompt_tokens: 10,
+									completion_tokens: 5,
+									total_tokens: 15,
+								},
+							}
+						}
+
+						// Return async iterator for streaming
+						return {
+							[Symbol.asyncIterator]: async function* () {
+								yield {
+									choices: [
+										{
+											delta: { content: "Test response" },
+											index: 0,
+										},
+									],
+									usage: null,
+								}
+								yield {
+									choices: [
+										{
+											delta: {},
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 5,
+										total_tokens: 15,
+									},
+								}
+							},
+						}
+					}),
+				},
+			},
+		})),
+	}
+})

-describe('DeepSeekHandler', () => {
-	let handler: DeepSeekHandler;
-	let mockOptions: ApiHandlerOptions;
+describe("DeepSeekHandler", () => {
+	let handler: DeepSeekHandler
+	let mockOptions: ApiHandlerOptions

-	beforeEach(() => {
-		mockOptions = {
-			deepSeekApiKey: 'test-api-key',
-			deepSeekModelId: 'deepseek-chat',
-			deepSeekBaseUrl: 'https://api.deepseek.com/v1'
-		};
-		handler = new DeepSeekHandler(mockOptions);
-		mockCreate.mockClear();
-	});
+	beforeEach(() => {
+		mockOptions = {
+			deepSeekApiKey: "test-api-key",
+			deepSeekModelId: "deepseek-chat",
+			deepSeekBaseUrl: "https://api.deepseek.com/v1",
+		}
+		handler = new DeepSeekHandler(mockOptions)
+		mockCreate.mockClear()
+	})

-	describe('constructor', () => {
-		it('should initialize with provided options', () => {
-			expect(handler).toBeInstanceOf(DeepSeekHandler);
-			expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId);
-		});
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(DeepSeekHandler)
+			expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId)
+		})

-		it('should throw error if API key is missing', () => {
-			expect(() => {
-				new DeepSeekHandler({
-					...mockOptions,
-					deepSeekApiKey: undefined
-				});
-			}).toThrow('DeepSeek API key is required');
-		});
+		it("should throw error if API key is missing", () => {
+			expect(() => {
+				new DeepSeekHandler({
+					...mockOptions,
+					deepSeekApiKey: undefined,
+				})
+			}).toThrow("DeepSeek API key is required")
+		})

-		it('should use default model ID if not provided', () => {
-			const handlerWithoutModel = new DeepSeekHandler({
-				...mockOptions,
-				deepSeekModelId: undefined
-			});
-			expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId);
-		});
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new DeepSeekHandler({
+				...mockOptions,
+				deepSeekModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId)
+		})

-		it('should use default base URL if not provided', () => {
-			const handlerWithoutBaseUrl = new DeepSeekHandler({
-				...mockOptions,
-				deepSeekBaseUrl: undefined
-			});
-			expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler);
-			// The base URL is passed to OpenAI client internally
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
-				baseURL: 'https://api.deepseek.com/v1'
-			}));
-		});
+		it("should use default base URL if not provided", () => {
+			const handlerWithoutBaseUrl = new DeepSeekHandler({
+				...mockOptions,
+				deepSeekBaseUrl: undefined,
+			})
+			expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler)
+			// The base URL is passed to OpenAI client internally
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.deepseek.com/v1",
+				}),
+			)
+		})

-		it('should use custom base URL if provided', () => {
-			const customBaseUrl = 'https://custom.deepseek.com/v1';
-			const handlerWithCustomUrl = new DeepSeekHandler({
-				...mockOptions,
-				deepSeekBaseUrl: customBaseUrl
-			});
-			expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler);
-			// The custom base URL is passed to OpenAI client
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
-				baseURL: customBaseUrl
-			}));
-		});
+		it("should use custom base URL if provided", () => {
+			const customBaseUrl = "https://custom.deepseek.com/v1"
+			const handlerWithCustomUrl = new DeepSeekHandler({
+				...mockOptions,
+				deepSeekBaseUrl: customBaseUrl,
+			})
+			expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler)
+			// The custom base URL is passed to OpenAI client
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: customBaseUrl,
+				}),
+			)
+		})

-		it('should set includeMaxTokens to true', () => {
-			// Create a new handler and verify OpenAI client was called with includeMaxTokens
-			new DeepSeekHandler(mockOptions);
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
-				apiKey: mockOptions.deepSeekApiKey
-			}));
-		});
-	});
+		it("should set includeMaxTokens to true", () => {
+			// Create a new handler and verify OpenAI client was called with includeMaxTokens
+			new DeepSeekHandler(mockOptions)
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					apiKey: mockOptions.deepSeekApiKey,
+				}),
+			)
+		})
+	})

-	describe('getModel', () => {
-		it('should return model info for valid model ID', () => {
-			const model = handler.getModel();
-			expect(model.id).toBe(mockOptions.deepSeekModelId);
-			expect(model.info).toBeDefined();
-			expect(model.info.maxTokens).toBe(8192);
-			expect(model.info.contextWindow).toBe(64_000);
-			expect(model.info.supportsImages).toBe(false);
-			expect(model.info.supportsPromptCache).toBe(false);
-		});
+	describe("getModel", () => {
+		it("should return model info for valid model ID", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe(mockOptions.deepSeekModelId)
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(64_000)
+			expect(model.info.supportsImages).toBe(false)
+			expect(model.info.supportsPromptCache).toBe(false)
+		})

-		it('should return provided model ID with default model info if model does not exist', () => {
-			const handlerWithInvalidModel = new DeepSeekHandler({
-				...mockOptions,
-				deepSeekModelId: 'invalid-model'
-			});
-			const model = handlerWithInvalidModel.getModel();
-			expect(model.id).toBe('invalid-model'); // Returns provided ID
-			expect(model.info).toBeDefined();
-			expect(model.info).toBe(handler.getModel().info); // But uses default model info
-		});
+		it("should return provided model ID with default model info if model does not exist", () => {
+			const handlerWithInvalidModel = new DeepSeekHandler({
+				...mockOptions,
+				deepSeekModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe("invalid-model") // Returns provided ID
+			expect(model.info).toBeDefined()
+			expect(model.info).toBe(handler.getModel().info) // But uses default model info
+		})

-		it('should return default model if no model ID is provided', () => {
-			const handlerWithoutModel = new DeepSeekHandler({
-				...mockOptions,
-				deepSeekModelId: undefined
-			});
-			const model = handlerWithoutModel.getModel();
-			expect(model.id).toBe(deepSeekDefaultModelId);
-			expect(model.info).toBeDefined();
-		});
-	});
+		it("should return default model if no model ID is provided", () => {
+			const handlerWithoutModel = new DeepSeekHandler({
+				...mockOptions,
+				deepSeekModelId: undefined,
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(deepSeekDefaultModelId)
+			expect(model.info).toBeDefined()
+		})
+	})

-	describe('createMessage', () => {
-		const systemPrompt = 'You are a helpful assistant.';
-		const messages: Anthropic.Messages.MessageParam[] = [
-			{
-				role: 'user',
-				content: [{
-					type: 'text' as const,
-					text: 'Hello!'
-				}]
-			}
-		];
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]

-		it('should handle streaming responses', async () => {
-			const stream = handler.createMessage(systemPrompt, messages);
-			const chunks: any[] = [];
-			for await (const chunk of stream) {
-				chunks.push(chunk);
-			}
+		it("should handle streaming responses", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}

-			expect(chunks.length).toBeGreaterThan(0);
-			const textChunks = chunks.filter(chunk => chunk.type === 'text');
-			expect(textChunks).toHaveLength(1);
-			expect(textChunks[0].text).toBe('Test response');
-		});
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})

-		it('should include usage information', async () => {
-			const stream = handler.createMessage(systemPrompt, messages);
-			const chunks: any[] = [];
-			for await (const chunk of stream) {
-				chunks.push(chunk);
-			}
+		it("should include usage information", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}

-			const usageChunks = chunks.filter(chunk => chunk.type === 'usage');
-			expect(usageChunks.length).toBeGreaterThan(0);
-			expect(usageChunks[0].inputTokens).toBe(10);
-			expect(usageChunks[0].outputTokens).toBe(5);
-		});
-	});
-});
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(5)
+		})
+	})
+})

@@ -1,212 +1,210 @@
-import { GeminiHandler } from '../gemini';
-import { Anthropic } from '@anthropic-ai/sdk';
-import { GoogleGenerativeAI } from '@google/generative-ai';
+import { GeminiHandler } from "../gemini"
+import { Anthropic } from "@anthropic-ai/sdk"
+import { GoogleGenerativeAI } from "@google/generative-ai"

 // Mock the Google Generative AI SDK
-jest.mock('@google/generative-ai', () => ({
-	GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
-		getGenerativeModel: jest.fn().mockReturnValue({
-			generateContentStream: jest.fn(),
-			generateContent: jest.fn().mockResolvedValue({
-				response: {
-					text: () => 'Test response'
-				}
-			})
-		})
-	}))
-}));
+jest.mock("@google/generative-ai", () => ({
+	GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
+		getGenerativeModel: jest.fn().mockReturnValue({
+			generateContentStream: jest.fn(),
+			generateContent: jest.fn().mockResolvedValue({
+				response: {
+					text: () => "Test response",
+				},
+			}),
+		}),
+	})),
+}))

-describe('GeminiHandler', () => {
-	let handler: GeminiHandler;
+describe("GeminiHandler", () => {
+	let handler: GeminiHandler

-	beforeEach(() => {
-		handler = new GeminiHandler({
-			apiKey: 'test-key',
-			apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
-			geminiApiKey: 'test-key'
-		});
-	});
+	beforeEach(() => {
+		handler = new GeminiHandler({
+			apiKey: "test-key",
+			apiModelId: "gemini-2.0-flash-thinking-exp-1219",
+			geminiApiKey: "test-key",
+		})
+	})

-	describe('constructor', () => {
-		it('should initialize with provided config', () => {
-			expect(handler['options'].geminiApiKey).toBe('test-key');
-			expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219');
-		});
+	describe("constructor", () => {
+		it("should initialize with provided config", () => {
+			expect(handler["options"].geminiApiKey).toBe("test-key")
+			expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
+		})

-		it('should throw if API key is missing', () => {
-			expect(() => {
-				new GeminiHandler({
-					apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
-					geminiApiKey: ''
-				});
-			}).toThrow('API key is required for Google Gemini');
-		});
-	});
+		it("should throw if API key is missing", () => {
+			expect(() => {
+				new GeminiHandler({
+					apiModelId: "gemini-2.0-flash-thinking-exp-1219",
+					geminiApiKey: "",
+				})
+			}).toThrow("API key is required for Google Gemini")
+		})
+	})

-	describe('createMessage', () => {
-		const mockMessages: Anthropic.Messages.MessageParam[] = [
-			{
-				role: 'user',
-				content: 'Hello'
-			},
-			{
-				role: 'assistant',
-				content: 'Hi there!'
-			}
-		];
+	describe("createMessage", () => {
+		const mockMessages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: "Hello",
+			},
+			{
+				role: "assistant",
+				content: "Hi there!",
+			},
+		]

-		const systemPrompt = 'You are a helpful assistant';
+		const systemPrompt = "You are a helpful assistant"

-		it('should handle text messages correctly', async () => {
-			// Mock the stream response
-			const mockStream = {
-				stream: [
-					{ text: () => 'Hello' },
-					{ text: () => ' world!' }
-				],
-				response: {
-					usageMetadata: {
-						promptTokenCount: 10,
-						candidatesTokenCount: 5
-					}
-				}
-			};
+		it("should handle text messages correctly", async () => {
+			// Mock the stream response
+			const mockStream = {
+				stream: [{ text: () => "Hello" }, { text: () => " world!" }],
+				response: {
+					usageMetadata: {
+						promptTokenCount: 10,
+						candidatesTokenCount: 5,
+					},
+				},
+			}

-			// Setup the mock implementation
-			const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream);
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContentStream: mockGenerateContentStream
-			});
+			// Setup the mock implementation
+			const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
+			const mockGetGenerativeModel = jest.fn().mockReturnValue({
+				generateContentStream: mockGenerateContentStream,
+			})

-			(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

-			const stream = handler.createMessage(systemPrompt, mockMessages);
-			const chunks = [];
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks = []

-			for await (const chunk of stream) {
-				chunks.push(chunk);
-			}
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}

-			// Should have 3 chunks: 'Hello', ' world!', and usage info
-			expect(chunks.length).toBe(3);
-			expect(chunks[0]).toEqual({
-				type: 'text',
-				text: 'Hello'
-			});
-			expect(chunks[1]).toEqual({
-				type: 'text',
-				text: ' world!'
-			});
-			expect(chunks[2]).toEqual({
-				type: 'usage',
-				inputTokens: 10,
-				outputTokens: 5
-			});
+			// Should have 3 chunks: 'Hello', ' world!', and usage info
+			expect(chunks.length).toBe(3)
+			expect(chunks[0]).toEqual({
+				type: "text",
+				text: "Hello",
+			})
+			expect(chunks[1]).toEqual({
+				type: "text",
+				text: " world!",
+			})
+			expect(chunks[2]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})

-			// Verify the model configuration
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
-				model: 'gemini-2.0-flash-thinking-exp-1219',
-				systemInstruction: systemPrompt
-			});
+			// Verify the model configuration
+			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
+				model: "gemini-2.0-flash-thinking-exp-1219",
+				systemInstruction: systemPrompt,
+			})

-			// Verify generation config
-			expect(mockGenerateContentStream).toHaveBeenCalledWith(
-				expect.objectContaining({
-					generationConfig: {
-						temperature: 0
-					}
-				})
-			);
-		});
+			// Verify generation config
+			expect(mockGenerateContentStream).toHaveBeenCalledWith(
+				expect.objectContaining({
+					generationConfig: {
+						temperature: 0,
+					},
+				}),
+			)
+		})

-		it('should handle API errors', async () => {
-			const mockError = new Error('Gemini API error');
-			const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError);
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContentStream: mockGenerateContentStream
-			});
+		it("should handle API errors", async () => {
+			const mockError = new Error("Gemini API error")
+			const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
+			const mockGetGenerativeModel = jest.fn().mockReturnValue({
+				generateContentStream: mockGenerateContentStream,
+			})

-			(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

-			const stream = handler.createMessage(systemPrompt, mockMessages);
+			const stream = handler.createMessage(systemPrompt, mockMessages)

-			await expect(async () => {
-				for await (const chunk of stream) {
-					// Should throw before yielding any chunks
-				}
-			}).rejects.toThrow('Gemini API error');
-		});
-	});
+			await expect(async () => {
+				for await (const chunk of stream) {
+					// Should throw before yielding any chunks
+				}
+			}).rejects.toThrow("Gemini API error")
+		})
+	})

-	describe('completePrompt', () => {
-		it('should complete prompt successfully', async () => {
-			const mockGenerateContent = jest.fn().mockResolvedValue({
-				response: {
-					text: () => 'Test response'
-				}
-			});
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent
-			});
-			(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+	describe("completePrompt", () => {
+		it("should complete prompt successfully", async () => {
+			const mockGenerateContent = jest.fn().mockResolvedValue({
+				response: {
+					text: () => "Test response",
+				},
+			})
+			const mockGetGenerativeModel = jest.fn().mockReturnValue({
+				generateContent: mockGenerateContent,
+			})
+			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('Test response');
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
-				model: 'gemini-2.0-flash-thinking-exp-1219'
-			});
-			expect(mockGenerateContent).toHaveBeenCalledWith({
-				contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
-				generationConfig: {
-					temperature: 0
-				}
-			});
-		});
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
+				model: "gemini-2.0-flash-thinking-exp-1219",
+			})
+			expect(mockGenerateContent).toHaveBeenCalledWith({
+				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
+				generationConfig: {
+					temperature: 0,
+				},
+			})
+		})

-		it('should handle API errors', async () => {
-			const mockError = new Error('Gemini API error');
-			const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent
-			});
-			(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+		it("should handle API errors", async () => {
+			const mockError = new Error("Gemini API error")
+			const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
+			const mockGetGenerativeModel = jest.fn().mockReturnValue({
+				generateContent: mockGenerateContent,
+			})
+			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

-			await expect(handler.completePrompt('Test prompt'))
-				.rejects.toThrow('Gemini completion error: Gemini API error');
-		});
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
+				"Gemini completion error: Gemini API error",
+			)
+		})

-		it('should handle empty response', async () => {
-			const mockGenerateContent = jest.fn().mockResolvedValue({
-				response: {
-					text: () => ''
-				}
-			});
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent
-			});
-			(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+		it("should handle empty response", async () => {
+			const mockGenerateContent = jest.fn().mockResolvedValue({
+				response: {
+					text: () => "",
+				},
+			})
+			const mockGetGenerativeModel = jest.fn().mockReturnValue({
+				generateContent: mockGenerateContent,
+			})
+			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

-			const result = await handler.completePrompt('Test prompt');
-			expect(result).toBe('');
-		});
-	});
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})
+	})

-	describe('getModel', () => {
-		it('should return correct model info', () => {
-			const modelInfo = handler.getModel();
-			expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219');
-			expect(modelInfo.info).toBeDefined();
-			expect(modelInfo.info.maxTokens).toBe(8192);
-			expect(modelInfo.info.contextWindow).toBe(32_767);
-		});
+	describe("getModel", () => {
+		it("should return correct model info", () => {
+			const modelInfo = handler.getModel()
+			expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
+			expect(modelInfo.info).toBeDefined()
+			expect(modelInfo.info.maxTokens).toBe(8192)
+			expect(modelInfo.info.contextWindow).toBe(32_767)
+		})

-		it('should return default model if invalid model specified', () => {
-			const invalidHandler = new GeminiHandler({
-				apiModelId: 'invalid-model',
-				geminiApiKey: 'test-key'
-			});
-			const modelInfo = invalidHandler.getModel();
-			expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); // Default model
-		});
-	});
-});
+		it("should return default model if invalid model specified", () => {
+			const invalidHandler = new GeminiHandler({
+				apiModelId: "invalid-model",
+				geminiApiKey: "test-key",
+			})
+			const modelInfo = invalidHandler.getModel()
+			expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model
+		})
+	})
+})

@@ -1,226 +1,238 @@
import { GlamaHandler } from "../glama"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"

// Mock OpenAI client
const mockCreate = jest.fn()
const mockWithResponse = jest.fn()

jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: (...args: any[]) => {
						const stream = {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}

						const result = mockCreate(...args)
						if (args[0].stream) {
							mockWithResponse.mockReturnValue(
								Promise.resolve({
									data: stream,
									response: {
										headers: {
											get: (name: string) =>
												name === "x-completion-request-id" ? "test-request-id" : null,
										},
									},
								}),
							)
							result.withResponse = mockWithResponse
						}
						return result
					},
				},
			},
		})),
	}
})

describe("GlamaHandler", () => {
	let handler: GlamaHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "anthropic/claude-3-5-sonnet",
			glamaModelId: "anthropic/claude-3-5-sonnet",
			glamaApiKey: "test-api-key",
		}
		handler = new GlamaHandler(mockOptions)
		mockCreate.mockClear()
		mockWithResponse.mockClear()

		// Default mock implementation for non-streaming responses
		mockCreate.mockResolvedValue({
			id: "test-completion",
			choices: [
				{
					message: { role: "assistant", content: "Test response" },
					finish_reason: "stop",
					index: 0,
				},
			],
			usage: {
				prompt_tokens: 10,
				completion_tokens: 5,
				total_tokens: 15,
			},
		})
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(GlamaHandler)
			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			// Mock axios for token usage request
			const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({
				data: {
					tokenUsage: {
						promptTokens: 10,
						completionTokens: 5,
						cacheCreationInputTokens: 0,
						cacheReadInputTokens: 0,
					},
					totalCostUsd: "0.00",
				},
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBe(2) // Text chunk and usage chunk
			expect(chunks[0]).toEqual({
				type: "text",
				text: "Test response",
			})
			expect(chunks[1]).toEqual({
				type: "usage",
				inputTokens: 10,
				outputTokens: 5,
				cacheWriteTokens: 0,
				cacheReadTokens: 0,
				totalCost: 0,
			})

			mockAxios.mockRestore()
		})

		it("should handle API errors", async () => {
			mockCreate.mockImplementationOnce(() => {
				throw new Error("API Error")
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const chunks = []

			try {
				for await (const chunk of stream) {
					chunks.push(chunk)
				}
				fail("Expected error to be thrown")
			} catch (error) {
				expect(error).toBeInstanceOf(Error)
				expect(error.message).toBe("API Error")
			}
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: mockOptions.apiModelId,
					messages: [{ role: "user", content: "Test prompt" }],
					temperature: 0,
					max_tokens: 8192,
				}),
			)
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error")
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})

		it("should not set max_tokens for non-Anthropic models", async () => {
			// Reset mock to clear any previous calls
			mockCreate.mockClear()

			const nonAnthropicOptions = {
				apiModelId: "openai/gpt-4",
				glamaModelId: "openai/gpt-4",
				glamaApiKey: "test-key",
				glamaModelInfo: {
					maxTokens: 4096,
					contextWindow: 8192,
					supportsImages: true,
					supportsPromptCache: false,
				},
			}
			const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions)

			await nonAnthropicHandler.completePrompt("Test prompt")
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "openai/gpt-4",
					messages: [{ role: "user", content: "Test prompt" }],
					temperature: 0,
				}),
			)
			expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.apiModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(8192)
			expect(modelInfo.info.contextWindow).toBe(200_000)
		})
	})
})

@@ -1,160 +1,167 @@
import { LmStudioHandler } from "../lmstudio"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate.mockImplementation(async (options) => {
						if (!options.stream) {
							return {
								id: "test-completion",
								choices: [
									{
										message: { role: "assistant", content: "Test response" },
										finish_reason: "stop",
										index: 0,
									},
								],
								usage: {
									prompt_tokens: 10,
									completion_tokens: 5,
									total_tokens: 15,
								},
							}
						}

						return {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}
					}),
				},
			},
		})),
	}
})

describe("LmStudioHandler", () => {
	let handler: LmStudioHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "local-model",
			lmStudioModelId: "local-model",
			lmStudioBaseUrl: "http://localhost:1234/v1",
		}
		handler = new LmStudioHandler(mockOptions)
		mockCreate.mockClear()
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(LmStudioHandler)
			expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId)
		})

		it("should use default base URL if not provided", () => {
			const handlerWithoutUrl = new LmStudioHandler({
				apiModelId: "local-model",
				lmStudioModelId: "local-model",
			})
			expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunks = chunks.filter((chunk) => chunk.type === "text")
			expect(textChunks).toHaveLength(1)
			expect(textChunks[0].text).toBe("Test response")
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))

			const stream = handler.createMessage(systemPrompt, messages)

			await expect(async () => {
				for await (const chunk of stream) {
					// Should not reach here
				}
			}).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong")
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: mockOptions.lmStudioModelId,
				messages: [{ role: "user", content: "Test prompt" }],
				temperature: 0,
				stream: false,
			})
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
				"Please check the LM Studio developer logs to debug what went wrong",
			)
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(-1)
			expect(modelInfo.info.contextWindow).toBe(128_000)
		})
	})
})

@@ -1,160 +1,165 @@
import { OllamaHandler } from "../ollama"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate.mockImplementation(async (options) => {
						if (!options.stream) {
							return {
								id: "test-completion",
								choices: [
									{
										message: { role: "assistant", content: "Test response" },
										finish_reason: "stop",
										index: 0,
									},
								],
								usage: {
									prompt_tokens: 10,
									completion_tokens: 5,
									total_tokens: 15,
								},
							}
						}

						return {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}
					}),
				},
			},
		})),
	}
})

describe("OllamaHandler", () => {
	let handler: OllamaHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "llama2",
			ollamaModelId: "llama2",
			ollamaBaseUrl: "http://localhost:11434/v1",
		}
		handler = new OllamaHandler(mockOptions)
		mockCreate.mockClear()
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(OllamaHandler)
			expect(handler.getModel().id).toBe(mockOptions.ollamaModelId)
		})

		it("should use default base URL if not provided", () => {
			const handlerWithoutUrl = new OllamaHandler({
				apiModelId: "llama2",
				ollamaModelId: "llama2",
			})
			expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunks = chunks.filter((chunk) => chunk.type === "text")
			expect(textChunks).toHaveLength(1)
			expect(textChunks[0].text).toBe("Test response")
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))

			const stream = handler.createMessage(systemPrompt, messages)

			await expect(async () => {
				for await (const chunk of stream) {
					// Should not reach here
				}
			}).rejects.toThrow("API Error")
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: mockOptions.ollamaModelId,
				messages: [{ role: "user", content: "Test prompt" }],
				temperature: 0,
				stream: false,
			})
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error")
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.ollamaModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(-1)
			expect(modelInfo.info.contextWindow).toBe(128_000)
		})
	})
})

@@ -1,319 +1,326 @@
import { OpenAiNativeHandler } from "../openai-native"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate.mockImplementation(async (options) => {
						if (!options.stream) {
							return {
								id: "test-completion",
								choices: [
									{
										message: { role: "assistant", content: "Test response" },
										finish_reason: "stop",
										index: 0,
									},
								],
								usage: {
									prompt_tokens: 10,
									completion_tokens: 5,
									total_tokens: 15,
								},
							}
						}

						return {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}
					}),
				},
			},
		})),
	}
})

describe("OpenAiNativeHandler", () => {
	let handler: OpenAiNativeHandler
	let mockOptions: ApiHandlerOptions
	const systemPrompt = "You are a helpful assistant."
	const messages: Anthropic.Messages.MessageParam[] = [
		{
			role: "user",
			content: "Hello!",
		},
	]

	beforeEach(() => {
		mockOptions = {
			apiModelId: "gpt-4o",
			openAiNativeApiKey: "test-api-key",
		}
		handler = new OpenAiNativeHandler(mockOptions)
		mockCreate.mockClear()
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(OpenAiNativeHandler)
			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
		})

		it("should initialize with empty API key", () => {
			const handlerWithoutKey = new OpenAiNativeHandler({
				apiModelId: "gpt-4o",
				openAiNativeApiKey: "",
			})
			expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler)
		})
	})

	describe("createMessage", () => {
		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunks = chunks.filter((chunk) => chunk.type === "text")
			expect(textChunks).toHaveLength(1)
			expect(textChunks[0].text).toBe("Test response")
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			const stream = handler.createMessage(systemPrompt, messages)
			await expect(async () => {
				for await (const chunk of stream) {
					// Should not reach here
				}
			}).rejects.toThrow("API Error")
		})

		it("should handle missing content in response for o1 model", async () => {
			// Use o1 model which supports developer role
			handler = new OpenAiNativeHandler({
				...mockOptions,
				apiModelId: "o1",
			})

			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: null } }],
				usage: {
					prompt_tokens: 0,
					completion_tokens: 0,
					total_tokens: 0,
				},
			})

			const generator = handler.createMessage(systemPrompt, messages)
			const results = []
			for await (const result of generator) {
				results.push(result)
			}

			expect(results).toEqual([
				{ type: "text", text: "" },
				{ type: "usage", inputTokens: 0, outputTokens: 0 },
			])

			// Verify developer role is used for system prompt with o1 model
			expect(mockCreate).toHaveBeenCalledWith({
				model: "o1",
				messages: [
					{ role: "developer", content: systemPrompt },
					{ role: "user", content: "Hello!" },
				],
			})
		})
	})

	describe("streaming models", () => {
		beforeEach(() => {
			handler = new OpenAiNativeHandler({
				...mockOptions,
				apiModelId: "gpt-4o",
			})
		})

		it("should handle streaming response", async () => {
			const mockStream = [
				{ choices: [{ delta: { content: "Hello" } }], usage: null },
				{ choices: [{ delta: { content: " there" } }], usage: null },
				{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
			]

			mockCreate.mockResolvedValueOnce(
				(async function* () {
					for (const chunk of mockStream) {
						yield chunk
					}
				})(),
			)

			const generator = handler.createMessage(systemPrompt, messages)
			const results = []
			for await (const result of generator) {
				results.push(result)
			}

			expect(results).toEqual([
				{ type: "text", text: "Hello" },
				{ type: "text", text: " there" },
				{ type: "text", text: "!" },
				{ type: "usage", inputTokens: 10, outputTokens: 5 },
			])

			expect(mockCreate).toHaveBeenCalledWith({
				model: "gpt-4o",
				temperature: 0,
				messages: [
					{ role: "system", content: systemPrompt },
					{ role: "user", content: "Hello!" },
				],
				stream: true,
				stream_options: { include_usage: true },
			})
		})

		it("should handle empty delta content", async () => {
			const mockStream = [
				{ choices: [{ delta: {} }], usage: null },
				{ choices: [{ delta: { content: null } }], usage: null },
				{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
			]

			mockCreate.mockResolvedValueOnce(
				(async function* () {
					for (const chunk of mockStream) {
						yield chunk
					}
				})(),
			)

			const generator = handler.createMessage(systemPrompt, messages)
			const results = []
			for await (const result of generator) {
				results.push(result)
			}

			expect(results).toEqual([
				{ type: "text", text: "Hello" },
				{ type: "usage", inputTokens: 10, outputTokens: 5 },
			])
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully with gpt-4o model", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: "gpt-4o",
				messages: [{ role: "user", content: "Test prompt" }],
				temperature: 0,
			})
		})

		it("should complete prompt successfully with o1 model", async () => {
			handler = new OpenAiNativeHandler({
				apiModelId: "o1",
				openAiNativeApiKey: "test-api-key",
			})

			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: "o1",
				messages: [{ role: "user", content: "Test prompt" }],
			})
		})

		it("should complete prompt successfully with o1-preview model", async () => {
			handler = new OpenAiNativeHandler({
				apiModelId: "o1-preview",
				openAiNativeApiKey: "test-api-key",
			})

			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: "o1-preview",
				messages: [{ role: "user", content: "Test prompt" }],
			})
		})

		it("should complete prompt successfully with o1-mini model", async () => {
			handler = new OpenAiNativeHandler({
				apiModelId: "o1-mini",
				openAiNativeApiKey: "test-api-key",
			})

			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: "o1-mini",
				messages: [{ role: "user", content: "Test prompt" }],
			})
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
				"OpenAI Native completion error: API Error",
			)
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.apiModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(4096)
			expect(modelInfo.info.contextWindow).toBe(128_000)
		})

		it("should handle undefined model ID", () => {
			const handlerWithoutModel = new OpenAiNativeHandler({
				openAiNativeApiKey: "test-api-key",
			})
			const modelInfo = handlerWithoutModel.getModel()
			expect(modelInfo.id).toBe("gpt-4o") // Default model
			expect(modelInfo.info).toBeDefined()
		})
	})
})

@@ -1,224 +1,233 @@
import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { ApiStream } from "../../transform/stream"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate.mockImplementation(async (options) => {
						if (!options.stream) {
							return {
								id: "test-completion",
								choices: [
									{
										message: { role: "assistant", content: "Test response", refusal: null },
										finish_reason: "stop",
										index: 0,
									},
								],
								usage: {
									prompt_tokens: 10,
									completion_tokens: 5,
									total_tokens: 15,
								},
							}
						}

						return {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}
					}),
				},
			},
		})),
	}
})

describe("OpenAiHandler", () => {
	let handler: OpenAiHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			openAiApiKey: "test-api-key",
			openAiModelId: "gpt-4",
			openAiBaseUrl: "https://api.openai.com/v1",
		}
		handler = new OpenAiHandler(mockOptions)
		mockCreate.mockClear()
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(OpenAiHandler)
			expect(handler.getModel().id).toBe(mockOptions.openAiModelId)
		})

		it("should use custom base URL if provided", () => {
			const customBaseUrl = "https://custom.openai.com/v1"
			const handlerWithCustomUrl = new OpenAiHandler({
				...mockOptions,
				openAiBaseUrl: customBaseUrl,
			})
			expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: [
					{
						type: "text" as const,
						text: "Hello!",
					},
				],
			},
		]

		it("should handle non-streaming mode", async () => {
			const handler = new OpenAiHandler({
				...mockOptions,
				openAiStreamingEnabled: false,
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunk = chunks.find((chunk) => chunk.type === "text")
			const usageChunk = chunks.find((chunk) => chunk.type === "usage")

			expect(textChunk).toBeDefined()
			expect(textChunk?.text).toBe("Test response")
			expect(usageChunk).toBeDefined()
			expect(usageChunk?.inputTokens).toBe(10)
			expect(usageChunk?.outputTokens).toBe(5)
		})

		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunks = chunks.filter((chunk) => chunk.type === "text")
			expect(textChunks).toHaveLength(1)
			expect(textChunks[0].text).toBe("Test response")
		})
	})

	describe("error handling", () => {
		const testMessages: Anthropic.Messages.MessageParam[] = [
			{
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
it('should handle rate limiting', async () => {
|
||||
const rateLimitError = new Error('Rate limit exceeded');
|
||||
rateLimitError.name = 'Error';
|
||||
(rateLimitError as any).status = 429;
|
||||
mockCreate.mockRejectedValueOnce(rateLimitError);
|
||||
const stream = handler.createMessage("system prompt", testMessages)
|
||||
|
||||
const stream = handler.createMessage('system prompt', testMessages);
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('Rate limit exceeded');
|
||||
});
|
||||
});
|
||||
it("should handle rate limiting", async () => {
|
||||
const rateLimitError = new Error("Rate limit exceeded")
|
||||
rateLimitError.name = "Error"
|
||||
;(rateLimitError as any).status = 429
|
||||
mockCreate.mockRejectedValueOnce(rateLimitError)
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openAiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
const stream = handler.createMessage("system prompt", testMessages)
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI completion error: API Error');
|
||||
});
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("Rate limit exceeded")
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockImplementationOnce(() => ({
|
||||
choices: [{ message: { content: '' } }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openAiModelId,
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info with sane defaults', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.openAiModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.contextWindow).toBe(128_000);
|
||||
expect(model.info.supportsImages).toBe(true);
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
const handlerWithoutModel = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBe('');
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockImplementationOnce(() => ({
|
||||
choices: [{ message: { content: "" } }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe("getModel", () => {
|
||||
it("should return model info with sane defaults", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.openAiModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.contextWindow).toBe(128_000)
|
||||
expect(model.info.supportsImages).toBe(true)
|
||||
})
|
||||
|
||||
it("should handle undefined model ID", () => {
|
||||
const handlerWithoutModel = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBe("")
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,283 +1,297 @@
import { OpenRouterHandler } from '../openrouter'
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
import OpenAI from 'openai'
import axios from 'axios'
import { Anthropic } from '@anthropic-ai/sdk'
import { OpenRouterHandler } from "../openrouter"
import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
import OpenAI from "openai"
import axios from "axios"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock dependencies
jest.mock('openai')
jest.mock('axios')
jest.mock('delay', () => jest.fn(() => Promise.resolve()))
jest.mock("openai")
jest.mock("axios")
jest.mock("delay", () => jest.fn(() => Promise.resolve()))

describe('OpenRouterHandler', () => {
const mockOptions: ApiHandlerOptions = {
openRouterApiKey: 'test-key',
openRouterModelId: 'test-model',
openRouterModelInfo: {
name: 'Test Model',
description: 'Test Description',
maxTokens: 1000,
contextWindow: 2000,
supportsPromptCache: true,
inputPrice: 0.01,
outputPrice: 0.02
} as ModelInfo
}
describe("OpenRouterHandler", () => {
const mockOptions: ApiHandlerOptions = {
openRouterApiKey: "test-key",
openRouterModelId: "test-model",
openRouterModelInfo: {
name: "Test Model",
description: "Test Description",
maxTokens: 1000,
contextWindow: 2000,
supportsPromptCache: true,
inputPrice: 0.01,
outputPrice: 0.02,
} as ModelInfo,
}

beforeEach(() => {
jest.clearAllMocks()
})
beforeEach(() => {
jest.clearAllMocks()
})

test('constructor initializes with correct options', () => {
const handler = new OpenRouterHandler(mockOptions)
expect(handler).toBeInstanceOf(OpenRouterHandler)
expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'https://openrouter.ai/api/v1',
apiKey: mockOptions.openRouterApiKey,
defaultHeaders: {
'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline',
'X-Title': 'Roo-Cline',
},
})
})
test("constructor initializes with correct options", () => {
const handler = new OpenRouterHandler(mockOptions)
expect(handler).toBeInstanceOf(OpenRouterHandler)
expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://openrouter.ai/api/v1",
apiKey: mockOptions.openRouterApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo-Cline",
},
})
})

test('getModel returns correct model info when options are provided', () => {
const handler = new OpenRouterHandler(mockOptions)
const result = handler.getModel()

expect(result).toEqual({
id: mockOptions.openRouterModelId,
info: mockOptions.openRouterModelInfo
})
})
test("getModel returns correct model info when options are provided", () => {
const handler = new OpenRouterHandler(mockOptions)
const result = handler.getModel()

test('getModel returns default model info when options are not provided', () => {
const handler = new OpenRouterHandler({})
const result = handler.getModel()

expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
expect(result.info.supportsPromptCache).toBe(true)
})
expect(result).toEqual({
id: mockOptions.openRouterModelId,
info: mockOptions.openRouterModelInfo,
})
})

test('createMessage generates correct stream chunks', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'test-id',
choices: [{
delta: {
content: 'test response'
}
}]
}
}
}
test("getModel returns default model info when options are not provided", () => {
const handler = new OpenRouterHandler({})
const result = handler.getModel()

// Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta")
expect(result.info.supportsPromptCache).toBe(true)
})

// Mock axios.get for generation details
;(axios.get as jest.Mock).mockResolvedValue({
data: {
data: {
native_tokens_prompt: 10,
native_tokens_completion: 20,
total_cost: 0.001
}
}
})
test("createMessage generates correct stream chunks", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
content: "test response",
},
},
],
}
},
}

const systemPrompt = 'test system prompt'
const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }]
// Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const generator = handler.createMessage(systemPrompt, messages)
const chunks = []

for await (const chunk of generator) {
chunks.push(chunk)
}
// Mock axios.get for generation details
;(axios.get as jest.Mock).mockResolvedValue({
data: {
data: {
native_tokens_prompt: 10,
native_tokens_completion: 20,
total_cost: 0.001,
},
},
})

// Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: 'test response'
})
expect(chunks[1]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 20,
totalCost: 0.001,
fullResponseText: 'test response'
})
const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]

// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: mockOptions.openRouterModelId,
temperature: 0,
messages: expect.arrayContaining([
{ role: 'system', content: systemPrompt },
{ role: 'user', content: 'test message' }
]),
stream: true
}))
})
const generator = handler.createMessage(systemPrompt, messages)
const chunks = []

test('createMessage with middle-out transform enabled', async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterUseMiddleOutTransform: true
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'test-id',
choices: [{
delta: {
content: 'test response'
}
}]
}
}
}
for await (const chunk of generator) {
chunks.push(chunk)
}

const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
// Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({
type: "text",
text: "test response",
})
expect(chunks[1]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 20,
totalCost: 0.001,
fullResponseText: "test response",
})

await handler.createMessage('test', []).next()
// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: mockOptions.openRouterModelId,
temperature: 0,
messages: expect.arrayContaining([
{ role: "system", content: systemPrompt },
{ role: "user", content: "test message" },
]),
stream: true,
}),
)
})

expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
transforms: ['middle-out']
}))
})
test("createMessage with middle-out transform enabled", async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterUseMiddleOutTransform: true,
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
content: "test response",
},
},
],
}
},
}

test('createMessage with Claude model adds cache control', async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterModelId: 'anthropic/claude-3.5-sonnet'
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'test-id',
choices: [{
delta: {
content: 'test response'
}
}]
}
}
}
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })

const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
await handler.createMessage("test", []).next()

const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'message 1' },
{ role: 'assistant', content: 'response 1' },
{ role: 'user', content: 'message 2' }
]
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
transforms: ["middle-out"],
}),
)
})

await handler.createMessage('test system', messages).next()
test("createMessage with Claude model adds cache control", async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterModelId: "anthropic/claude-3.5-sonnet",
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
content: "test response",
},
},
],
}
},
}

expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: expect.arrayContaining([
expect.objectContaining({
cache_control: { type: 'ephemeral' }
})
])
})
])
}))
})
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })

test('createMessage handles API errors', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
error: {
message: 'API Error',
code: 500
}
}
}
}
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "message 1" },
{ role: "assistant", content: "response 1" },
{ role: "user", content: "message 2" },
]

const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
await handler.createMessage("test system", messages).next()

const generator = handler.createMessage('test', [])
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
})
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: "system",
content: expect.arrayContaining([
expect.objectContaining({
cache_control: { type: "ephemeral" },
}),
]),
}),
]),
}),
)
})

test('completePrompt returns correct response', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockResponse = {
choices: [{
message: {
content: 'test completion'
}
}]
}
test("createMessage handles API errors", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
error: {
message: "API Error",
code: 500,
},
}
},
}

const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const result = await handler.completePrompt('test prompt')
const generator = handler.createMessage("test", [])
await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
})

expect(result).toBe('test completion')
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openRouterModelId,
messages: [{ role: 'user', content: 'test prompt' }],
temperature: 0,
stream: false
})
})
test("completePrompt returns correct response", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockResponse = {
choices: [
{
message: {
content: "test completion",
},
},
],
}

test('completePrompt handles API errors', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockError = {
error: {
message: 'API Error',
code: 500
}
}
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const mockCreate = jest.fn().mockResolvedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
const result = await handler.completePrompt("test prompt")

await expect(handler.completePrompt('test prompt'))
.rejects.toThrow('OpenRouter API Error 500: API Error')
})
expect(result).toBe("test completion")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openRouterModelId,
messages: [{ role: "user", content: "test prompt" }],
temperature: 0,
stream: false,
})
})

test('completePrompt handles unexpected errors', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
test("completePrompt handles API errors", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockError = {
error: {
message: "API Error",
code: 500,
},
}

await expect(handler.completePrompt('test prompt'))
.rejects.toThrow('OpenRouter completion error: Unexpected error')
})
const mockCreate = jest.fn().mockResolvedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error")
})

test("completePrompt handles unexpected errors", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error"))
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

await expect(handler.completePrompt("test prompt")).rejects.toThrow(
"OpenRouter completion error: Unexpected error",
)
})
})

@@ -1,296 +1,295 @@
import { VertexHandler } from '../vertex';
import { Anthropic } from '@anthropic-ai/sdk';
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
import { VertexHandler } from "../vertex"
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"

// Mock Vertex SDK
jest.mock('@anthropic-ai/vertex-sdk', () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: {
create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
}));
jest.mock("@anthropic-ai/vertex-sdk", () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: {
create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
content: [{ type: "text", text: "Test response" }],
role: "assistant",
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5,
},
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 5,
},
},
}
yield {
type: "content_block_start",
content_block: {
type: "text",
text: "Test response",
},
}
},
}
}),
},
})),
}))

describe('VertexHandler', () => {
let handler: VertexHandler;
describe("VertexHandler", () => {
let handler: VertexHandler

beforeEach(() => {
handler = new VertexHandler({
apiModelId: 'claude-3-5-sonnet-v2@20241022',
vertexProjectId: 'test-project',
vertexRegion: 'us-central1'
});
});
beforeEach(() => {
handler = new VertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})
})

describe('constructor', () => {
it('should initialize with provided config', () => {
expect(AnthropicVertex).toHaveBeenCalledWith({
projectId: 'test-project',
region: 'us-central1'
});
});
});
describe("constructor", () => {
it("should initialize with provided config", () => {
expect(AnthropicVertex).toHaveBeenCalledWith({
projectId: "test-project",
region: "us-central1",
})
})
})

describe('createMessage', () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
}
];
describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
{
role: "assistant",
content: "Hi there!",
},
]

const systemPrompt = 'You are a helpful assistant';
const systemPrompt = "You are a helpful assistant"

it('should handle streaming responses correctly', async () => {
const mockStream = [
{
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 0
}
}
},
{
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: 'Hello'
}
},
{
type: 'content_block_delta',
delta: {
type: 'text_delta',
text: ' world!'
}
},
{
type: 'message_delta',
usage: {
output_tokens: 5
}
}
];
it("should handle streaming responses correctly", async () => {
const mockStream = [
{
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 0,
},
},
},
{
type: "content_block_start",
index: 0,
content_block: {
type: "text",
text: "Hello",
},
},
{
type: "content_block_delta",
delta: {
type: "text_delta",
text: " world!",
},
},
{
type: "message_delta",
usage: {
output_tokens: 5,
},
},
]

// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk;
}
}
};
// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk
}
},
}

const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
(handler['client'].messages as any).create = mockCreate;
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
;(handler["client"].messages as any).create = mockCreate

const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];

for await (const chunk of stream) {
chunks.push(chunk);
}
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

expect(chunks.length).toBe(4);
expect(chunks[0]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 0
});
expect(chunks[1]).toEqual({
type: 'text',
text: 'Hello'
});
expect(chunks[2]).toEqual({
type: 'text',
text: ' world!'
});
expect(chunks[3]).toEqual({
type: 'usage',
inputTokens: 0,
outputTokens: 5
});
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(mockCreate).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022',
max_tokens: 8192,
temperature: 0,
system: systemPrompt,
messages: mockMessages,
stream: true
});
});
expect(chunks.length).toBe(4)
expect(chunks[0]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 0,
})
expect(chunks[1]).toEqual({
type: "text",
text: "Hello",
})
expect(chunks[2]).toEqual({
type: "text",
text: " world!",
})
expect(chunks[3]).toEqual({
type: "usage",
inputTokens: 0,
outputTokens: 5,
})

it('should handle multiple content blocks with line breaks', async () => {
const mockStream = [
{
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: 'First line'
}
},
{
type: 'content_block_start',
index: 1,
content_block: {
type: 'text',
text: 'Second line'
}
}
];
expect(mockCreate).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
system: systemPrompt,
messages: mockMessages,
stream: true,
})
})

const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk;
}
}
};
it("should handle multiple content blocks with line breaks", async () => {
const mockStream = [
{
type: "content_block_start",
index: 0,
content_block: {
type: "text",
text: "First line",
},
},
{
type: "content_block_start",
index: 1,
content_block: {
type: "text",
text: "Second line",
},
},
]

const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
(handler['client'].messages as any).create = mockCreate;
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk
}
},
}

const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];

for await (const chunk of stream) {
chunks.push(chunk);
}
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
;(handler["client"].messages as any).create = mockCreate

expect(chunks.length).toBe(3);
expect(chunks[0]).toEqual({
type: 'text',
text: 'First line'
});
expect(chunks[1]).toEqual({
type: 'text',
text: '\n'
});
expect(chunks[2]).toEqual({
type: 'text',
text: 'Second line'
});
});
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

it('should handle API errors', async () => {
const mockError = new Error('Vertex API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].messages as any).create = mockCreate;
for await (const chunk of stream) {
chunks.push(chunk)
}

const stream = handler.createMessage(systemPrompt, mockMessages);
expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({
type: "text",
text: "First line",
})
expect(chunks[1]).toEqual({
type: "text",
text: "\n",
})
expect(chunks[2]).toEqual({
type: "text",
text: "Second line",
})
})

await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow('Vertex API error');
});
});
it("should handle API errors", async () => {
const mockError = new Error("Vertex API error")
const mockCreate = jest.fn().mockRejectedValue(mockError)
;(handler["client"].messages as any).create = mockCreate

describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(handler['client'].messages.create).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022',
max_tokens: 8192,
temperature: 0,
messages: [{ role: 'user', content: 'Test prompt' }],
stream: false
});
});
const stream = handler.createMessage(systemPrompt, mockMessages)

it('should handle API errors', async () => {
const mockError = new Error('Vertex API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].messages as any).create = mockCreate;
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow("Vertex API error")
})
})

await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Vertex completion error: Vertex API error');
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(handler["client"].messages.create).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [{ role: "user", content: "Test prompt" }],
stream: false,
})
})

it('should handle non-text content', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'image' }]
});
(handler['client'].messages as any).create = mockCreate;
it("should handle API errors", async () => {
const mockError = new Error("Vertex API error")
const mockCreate = jest.fn().mockRejectedValue(mockError)
;(handler["client"].messages as any).create = mockCreate

const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Vertex completion error: Vertex API error",
)
})

it('should handle empty response', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'text', text: '' }]
});
(handler['client'].messages as any).create = mockCreate;
it("should handle non-text content", async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: "image" }],
})
;(handler["client"].messages as any).create = mockCreate

const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022');
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(8192);
expect(modelInfo.info.contextWindow).toBe(200_000);
});
it("should handle empty response", async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: "text", text: "" }],
})
;(handler["client"].messages as any).create = mockCreate

it('should return default model if invalid model specified', () => {
const invalidHandler = new VertexHandler({
apiModelId: 'invalid-model',
vertexProjectId: 'test-project',
vertexRegion: 'us-central1'
});
const modelInfo = invalidHandler.getModel();
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model
});
});
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})

describe("getModel", () => {
it("should return correct model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(200_000)
})

it("should return default model if invalid model specified", () => {
const invalidHandler = new VertexHandler({
apiModelId: "invalid-model",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})
const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model
})
})
})

@@ -1,289 +1,295 @@
import * as vscode from 'vscode';
import { VsCodeLmHandler } from '../vscode-lm';
import { ApiHandlerOptions } from '../../../shared/api';
import { Anthropic } from '@anthropic-ai/sdk';
import * as vscode from "vscode"
import { VsCodeLmHandler } from "../vscode-lm"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock vscode namespace
jest.mock('vscode', () => {
jest.mock("vscode", () => {
class MockLanguageModelTextPart {
type = 'text';
type = "text"
constructor(public value: string) {}
}

class MockLanguageModelToolCallPart {
type = 'tool_call';
type = "tool_call"
constructor(
public callId: string,
public name: string,
public input: any
public input: any,
) {}
}

return {
workspace: {
onDidChangeConfiguration: jest.fn((callback) => ({
dispose: jest.fn()
}))
dispose: jest.fn(),
})),
},
CancellationTokenSource: jest.fn(() => ({
token: {
isCancellationRequested: false,
onCancellationRequested: jest.fn()
onCancellationRequested: jest.fn(),
},
cancel: jest.fn(),
dispose: jest.fn()
dispose: jest.fn(),
})),
CancellationError: class CancellationError extends Error {
constructor() {
super('Operation cancelled');
this.name = 'CancellationError';
super("Operation cancelled")
this.name = "CancellationError"
}
},
LanguageModelChatMessage: {
Assistant: jest.fn((content) => ({
role: 'assistant',
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
role: "assistant",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})),
User: jest.fn((content) => ({
role: 'user',
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
}))
role: "user",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})),
},
LanguageModelTextPart: MockLanguageModelTextPart,
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
lm: {
selectChatModels: jest.fn()
}
};
});
selectChatModels: jest.fn(),
},
}
})

const mockLanguageModelChat = {
id: 'test-model',
name: 'Test Model',
vendor: 'test-vendor',
family: 'test-family',
version: '1.0',
id: "test-model",
name: "Test Model",
vendor: "test-vendor",
family: "test-family",
version: "1.0",
maxInputTokens: 4096,
sendRequest: jest.fn(),
countTokens: jest.fn()
};
countTokens: jest.fn(),
}

describe('VsCodeLmHandler', () => {
let handler: VsCodeLmHandler;
describe("VsCodeLmHandler", () => {
let handler: VsCodeLmHandler
const defaultOptions: ApiHandlerOptions = {
vsCodeLmModelSelector: {
vendor: 'test-vendor',
family: 'test-family'
}
};
vendor: "test-vendor",
family: "test-family",
},
}

beforeEach(() => {
jest.clearAllMocks();
handler = new VsCodeLmHandler(defaultOptions);
});
jest.clearAllMocks()
handler = new VsCodeLmHandler(defaultOptions)
})

afterEach(() => {
handler.dispose();
});
handler.dispose()
})

describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeDefined();
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled();
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeDefined()
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled()
})

it('should handle configuration changes', () => {
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0];
callback({ affectsConfiguration: () => true });
it("should handle configuration changes", () => {
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]
callback({ affectsConfiguration: () => true })
// Should reset client when config changes
expect(handler['client']).toBeNull();
});
});
expect(handler["client"]).toBeNull()
})
})

describe('createClient', () => {
it('should create client with selector', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
describe("createClient", () => {
it("should create client with selector", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])

const client = await handler['createClient']({
vendor: 'test-vendor',
family: 'test-family'
});
const client = await handler["createClient"]({
vendor: "test-vendor",
family: "test-family",
})

expect(client).toBeDefined();
expect(client.id).toBe('test-model');
expect(client).toBeDefined()
expect(client.id).toBe("test-model")
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
vendor: 'test-vendor',
family: 'test-family'
});
});
vendor: "test-vendor",
family: "test-family",
})
})

it('should return default client when no models available', async () => {
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]);
it("should return default client when no models available", async () => {
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([])

const client = await handler['createClient']({});

expect(client).toBeDefined();
expect(client.id).toBe('default-lm');
expect(client.vendor).toBe('vscode');
});
});
const client = await handler["createClient"]({})

describe('createMessage', () => {
expect(client).toBeDefined()
expect(client.id).toBe("default-lm")
expect(client.vendor).toBe("vscode")
})
})

describe("createMessage", () => {
beforeEach(() => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
mockLanguageModelChat.countTokens.mockResolvedValue(10);
});
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
mockLanguageModelChat.countTokens.mockResolvedValue(10)
})

it('should stream text responses', async () => {
const systemPrompt = 'You are a helpful assistant';
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user' as const,
content: 'Hello'
}];
it("should stream text responses", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Hello",
},
]

const responseText = 'Hello! How can I help you?';
const responseText = "Hello! How can I help you?"
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
yield new vscode.LanguageModelTextPart(responseText);
return;
yield new vscode.LanguageModelTextPart(responseText)
return
})(),
text: (async function* () {
yield responseText;
return;
})()
});
yield responseText
return
})(),
})

const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk);
chunks.push(chunk)
}

expect(chunks).toHaveLength(2); // Text chunk + usage chunk
expect(chunks).toHaveLength(2) // Text chunk + usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: responseText
});
type: "text",
text: responseText,
})
expect(chunks[1]).toMatchObject({
type: 'usage',
type: "usage",
inputTokens: expect.any(Number),
outputTokens: expect.any(Number)
});
});
outputTokens: expect.any(Number),
})
})

it('should handle tool calls', async () => {
const systemPrompt = 'You are a helpful assistant';
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user' as const,
content: 'Calculate 2+2'
}];
it("should handle tool calls", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Calculate 2+2",
},
]

const toolCallData = {
name: 'calculator',
arguments: { operation: 'add', numbers: [2, 2] },
callId: 'call-1'
};
name: "calculator",
arguments: { operation: "add", numbers: [2, 2] },
callId: "call-1",
}

mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
yield new vscode.LanguageModelToolCallPart(
toolCallData.callId,
toolCallData.name,
toolCallData.arguments
);
return;
toolCallData.arguments,
)
return
})(),
text: (async function* () {
yield JSON.stringify({ type: 'tool_call', ...toolCallData });
return;
})()
});
yield JSON.stringify({ type: "tool_call", ...toolCallData })
return
})(),
})

const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk);
chunks.push(chunk)
}

expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk
expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: JSON.stringify({ type: 'tool_call', ...toolCallData })
});
});
type: "text",
text: JSON.stringify({ type: "tool_call", ...toolCallData }),
})
})

it('should handle errors', async () => {
const systemPrompt = 'You are a helpful assistant';
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user' as const,
content: 'Hello'
}];
it("should handle errors", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Hello",
},
]

mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error'));
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))

await expect(async () => {
const stream = handler.createMessage(systemPrompt, messages);
const stream = handler.createMessage(systemPrompt, messages)
for await (const _ of stream) {
// consume stream
}
}).rejects.toThrow('API Error');
});
});
}).rejects.toThrow("API Error")
})
})

describe("getModel", () => {
it("should return model info when client exists", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])

describe('getModel', () => {
it('should return model info when client exists', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);

// Initialize client
await handler['getClient']();

const model = handler.getModel();
expect(model.id).toBe('test-model');
expect(model.info).toBeDefined();
expect(model.info.contextWindow).toBe(4096);
});
await handler["getClient"]()

it('should return fallback model info when no client exists', () => {
const model = handler.getModel();
expect(model.id).toBe('test-vendor/test-family');
expect(model.info).toBeDefined();
});
});
const model = handler.getModel()
expect(model.id).toBe("test-model")
expect(model.info).toBeDefined()
expect(model.info.contextWindow).toBe(4096)
})

describe('completePrompt', () => {
it('should complete single prompt', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
it("should return fallback model info when no client exists", () => {
const model = handler.getModel()
expect(model.id).toBe("test-vendor/test-family")
expect(model.info).toBeDefined()
})
})

const responseText = 'Completed text';
describe("completePrompt", () => {
it("should complete single prompt", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])

const responseText = "Completed text"
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
yield new vscode.LanguageModelTextPart(responseText);
return;
yield new vscode.LanguageModelTextPart(responseText)
return
})(),
text: (async function* () {
yield responseText;
return;
})()
});
yield responseText
return
})(),
})

const result = await handler.completePrompt('Test prompt');
expect(result).toBe(responseText);
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled();
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe(responseText)
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
})

it('should handle errors during completion', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
it("should handle errors during completion", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])

mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed'));
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))

await expect(handler.completePrompt('Test prompt'))
.rejects
.toThrow('VSCode LM completion error: Completion failed');
});
});
});
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"VSCode LM completion error: Completion failed",
)
})
})
})