Prettier backfill

This commit is contained in:
Matt Rubens
2025-01-17 14:11:28 -05:00
parent 3bcb4ff8c5
commit 60a0a824b9
174 changed files with 15715 additions and 15428 deletions

View File

@@ -1,239 +1,238 @@
import { AnthropicHandler } from '../anthropic';
import { ApiHandlerOptions } from '../../../shared/api';
import { ApiStream } from '../../transform/stream';
import { Anthropic } from '@anthropic-ai/sdk';
import { AnthropicHandler } from "../anthropic"
import { ApiHandlerOptions } from "../../../shared/api"
import { ApiStream } from "../../transform/stream"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock Anthropic client
const mockBetaCreate = jest.fn();
const mockCreate = jest.fn();
jest.mock('@anthropic-ai/sdk', () => {
return {
Anthropic: jest.fn().mockImplementation(() => ({
beta: {
promptCaching: {
messages: {
create: mockBetaCreate.mockImplementation(async () => ({
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 20,
cache_read_input_tokens: 10
}
}
};
yield {
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: 'Hello'
}
};
yield {
type: 'content_block_delta',
delta: {
type: 'text_delta',
text: ' world'
}
};
}
}))
}
}
},
messages: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
};
});
const mockBetaCreate = jest.fn()
const mockCreate = jest.fn()
jest.mock("@anthropic-ai/sdk", () => {
return {
Anthropic: jest.fn().mockImplementation(() => ({
beta: {
promptCaching: {
messages: {
create: mockBetaCreate.mockImplementation(async () => ({
async *[Symbol.asyncIterator]() {
yield {
type: "message_start",
message: {
usage: {
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 20,
cache_read_input_tokens: 10,
},
},
}
yield {
type: "content_block_start",
index: 0,
content_block: {
type: "text",
text: "Hello",
},
}
yield {
type: "content_block_delta",
delta: {
type: "text_delta",
text: " world",
},
}
},
})),
},
},
},
messages: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
content: [{ type: "text", text: "Test response" }],
role: "assistant",
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5,
},
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 5,
},
},
}
yield {
type: "content_block_start",
content_block: {
type: "text",
text: "Test response",
},
}
},
}
}),
},
})),
}
})
describe('AnthropicHandler', () => {
let handler: AnthropicHandler;
let mockOptions: ApiHandlerOptions;
describe("AnthropicHandler", () => {
let handler: AnthropicHandler
let mockOptions: ApiHandlerOptions
beforeEach(() => {
mockOptions = {
apiKey: 'test-api-key',
apiModelId: 'claude-3-5-sonnet-20241022'
};
handler = new AnthropicHandler(mockOptions);
mockBetaCreate.mockClear();
mockCreate.mockClear();
});
beforeEach(() => {
mockOptions = {
apiKey: "test-api-key",
apiModelId: "claude-3-5-sonnet-20241022",
}
handler = new AnthropicHandler(mockOptions)
mockBetaCreate.mockClear()
mockCreate.mockClear()
})
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(AnthropicHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(AnthropicHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
})
it('should initialize with undefined API key', () => {
// The SDK will handle API key validation, so we just verify it initializes
const handlerWithoutKey = new AnthropicHandler({
...mockOptions,
apiKey: undefined
});
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler);
});
it("should initialize with undefined API key", () => {
// The SDK will handle API key validation, so we just verify it initializes
const handlerWithoutKey = new AnthropicHandler({
...mockOptions,
apiKey: undefined,
})
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler)
})
it('should use custom base URL if provided', () => {
const customBaseUrl = 'https://custom.anthropic.com';
const handlerWithCustomUrl = new AnthropicHandler({
...mockOptions,
anthropicBaseUrl: customBaseUrl
});
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler);
});
});
it("should use custom base URL if provided", () => {
const customBaseUrl = "https://custom.anthropic.com"
const handlerWithCustomUrl = new AnthropicHandler({
...mockOptions,
anthropicBaseUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler)
})
})
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [{
type: 'text' as const,
text: 'Hello!'
}]
}
];
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text" as const,
text: "Hello!",
},
],
},
]
it('should handle prompt caching for supported models', async () => {
const stream = handler.createMessage(systemPrompt, [
{
role: 'user',
content: [{ type: 'text' as const, text: 'First message' }]
},
{
role: 'assistant',
content: [{ type: 'text' as const, text: 'Response' }]
},
{
role: 'user',
content: [{ type: 'text' as const, text: 'Second message' }]
}
]);
it("should handle prompt caching for supported models", async () => {
const stream = handler.createMessage(systemPrompt, [
{
role: "user",
content: [{ type: "text" as const, text: "First message" }],
},
{
role: "assistant",
content: [{ type: "text" as const, text: "Response" }],
},
{
role: "user",
content: [{ type: "text" as const, text: "Second message" }],
},
])
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
// Verify usage information
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
expect(usageChunk).toBeDefined();
expect(usageChunk?.inputTokens).toBe(100);
expect(usageChunk?.outputTokens).toBe(50);
expect(usageChunk?.cacheWriteTokens).toBe(20);
expect(usageChunk?.cacheReadTokens).toBe(10);
// Verify usage information
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk?.inputTokens).toBe(100)
expect(usageChunk?.outputTokens).toBe(50)
expect(usageChunk?.cacheWriteTokens).toBe(20)
expect(usageChunk?.cacheReadTokens).toBe(10)
// Verify text content
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(2);
expect(textChunks[0].text).toBe('Hello');
expect(textChunks[1].text).toBe(' world');
// Verify text content
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(2)
expect(textChunks[0].text).toBe("Hello")
expect(textChunks[1].text).toBe(" world")
// Verify beta API was used
expect(mockBetaCreate).toHaveBeenCalled();
expect(mockCreate).not.toHaveBeenCalled();
});
});
// Verify beta API was used
expect(mockBetaCreate).toHaveBeenCalled()
expect(mockCreate).not.toHaveBeenCalled()
})
})
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
max_tokens: 8192,
temperature: 0,
stream: false
});
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
stream: false,
})
})
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Anthropic completion error: API Error');
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error")
})
it('should handle non-text content', async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'image' }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it("should handle non-text content", async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: "image" }],
}))
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
it('should handle empty response', async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'text', text: '' }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
it("should handle empty response", async () => {
mockCreate.mockImplementationOnce(async () => ({
content: [{ type: "text", text: "" }],
}))
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
describe('getModel', () => {
it('should return default model if no model ID is provided', () => {
const handlerWithoutModel = new AnthropicHandler({
...mockOptions,
apiModelId: undefined
});
const model = handlerWithoutModel.getModel();
expect(model.id).toBeDefined();
expect(model.info).toBeDefined();
});
describe("getModel", () => {
it("should return default model if no model ID is provided", () => {
const handlerWithoutModel = new AnthropicHandler({
...mockOptions,
apiModelId: undefined,
})
const model = handlerWithoutModel.getModel()
expect(model.id).toBeDefined()
expect(model.info).toBeDefined()
})
it('should return specified model if valid model ID is provided', () => {
const model = handler.getModel();
expect(model.id).toBe(mockOptions.apiModelId);
expect(model.info).toBeDefined();
expect(model.info.maxTokens).toBe(8192);
expect(model.info.contextWindow).toBe(200_000);
expect(model.info.supportsImages).toBe(true);
expect(model.info.supportsPromptCache).toBe(true);
});
});
});
it("should return specified model if valid model ID is provided", () => {
const model = handler.getModel()
expect(model.id).toBe(mockOptions.apiModelId)
expect(model.info).toBeDefined()
expect(model.info.maxTokens).toBe(8192)
expect(model.info.contextWindow).toBe(200_000)
expect(model.info.supportsImages).toBe(true)
expect(model.info.supportsPromptCache).toBe(true)
})
})
})

View File

@@ -1,246 +1,259 @@
import { AwsBedrockHandler } from '../bedrock';
import { MessageContent } from '../../../shared/api';
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
import { Anthropic } from '@anthropic-ai/sdk';
import { AwsBedrockHandler } from "../bedrock"
import { MessageContent } from "../../../shared/api"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"
describe('AwsBedrockHandler', () => {
let handler: AwsBedrockHandler;
describe("AwsBedrockHandler", () => {
let handler: AwsBedrockHandler
beforeEach(() => {
handler = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
awsAccessKey: 'test-access-key',
awsSecretKey: 'test-secret-key',
awsRegion: 'us-east-1'
});
});
beforeEach(() => {
handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})
})
describe('constructor', () => {
it('should initialize with provided config', () => {
expect(handler['options'].awsAccessKey).toBe('test-access-key');
expect(handler['options'].awsSecretKey).toBe('test-secret-key');
expect(handler['options'].awsRegion).toBe('us-east-1');
expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
});
describe("constructor", () => {
it("should initialize with provided config", () => {
expect(handler["options"].awsAccessKey).toBe("test-access-key")
expect(handler["options"].awsSecretKey).toBe("test-secret-key")
expect(handler["options"].awsRegion).toBe("us-east-1")
expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
})
it('should initialize with missing AWS credentials', () => {
const handlerWithoutCreds = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
awsRegion: 'us-east-1'
});
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler);
});
});
it("should initialize with missing AWS credentials", () => {
const handlerWithoutCreds = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsRegion: "us-east-1",
})
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
})
})
describe('createMessage', () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
}
];
describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
{
role: "assistant",
content: "Hi there!",
},
]
const systemPrompt = 'You are a helpful assistant';
const systemPrompt = "You are a helpful assistant"
it('should handle text messages correctly', async () => {
const mockResponse = {
messages: [{
role: 'assistant',
content: [{ type: 'text', text: 'Hello! How can I help you?' }]
}],
usage: {
input_tokens: 10,
output_tokens: 5
}
};
it("should handle text messages correctly", async () => {
const mockResponse = {
messages: [
{
role: "assistant",
content: [{ type: "text", text: "Hello! How can I help you?" }],
},
],
usage: {
input_tokens: 10,
output_tokens: 5,
},
}
// Mock AWS SDK invoke
const mockStream = {
[Symbol.asyncIterator]: async function* () {
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5
}
}
};
}
};
// Mock AWS SDK invoke
const mockStream = {
[Symbol.asyncIterator]: async function* () {
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
},
},
}
},
}
const mockInvoke = jest.fn().mockResolvedValue({
stream: mockStream
});
const mockInvoke = jest.fn().mockResolvedValue({
stream: mockStream,
})
handler['client'] = {
send: mockInvoke
} as unknown as BedrockRuntimeClient;
handler["client"] = {
send: mockInvoke,
} as unknown as BedrockRuntimeClient
const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
expect(chunks.length).toBeGreaterThan(0);
expect(chunks[0]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 5
});
for await (const chunk of stream) {
chunks.push(chunk)
}
expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0'
})
}));
});
expect(chunks.length).toBeGreaterThan(0)
expect(chunks[0]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
})
it('should handle API errors', async () => {
// Mock AWS SDK invoke with error
const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error'));
expect(mockInvoke).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
}),
}),
)
})
handler['client'] = {
send: mockInvoke
} as unknown as BedrockRuntimeClient;
it("should handle API errors", async () => {
// Mock AWS SDK invoke with error
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
const stream = handler.createMessage(systemPrompt, mockMessages);
handler["client"] = {
send: mockInvoke,
} as unknown as BedrockRuntimeClient
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow('AWS Bedrock error');
});
});
const stream = handler.createMessage(systemPrompt, mockMessages)
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({
content: 'Test response'
}))
};
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow("AWS Bedrock error")
})
})
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const mockResponse = {
output: new TextEncoder().encode(
JSON.stringify({
content: "Test response",
}),
),
}
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
messages: expect.arrayContaining([
expect.objectContaining({
role: 'user',
content: [{ text: 'Test prompt' }]
})
]),
inferenceConfig: expect.objectContaining({
maxTokens: 5000,
temperature: 0.3,
topP: 0.1
})
})
}));
});
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
it('should handle API errors', async () => {
const mockError = new Error('AWS Bedrock error');
const mockSend = jest.fn().mockRejectedValue(mockError);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
messages: expect.arrayContaining([
expect.objectContaining({
role: "user",
content: [{ text: "Test prompt" }],
}),
]),
inferenceConfig: expect.objectContaining({
maxTokens: 5000,
temperature: 0.3,
topP: 0.1,
}),
}),
}),
)
})
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
});
it("should handle API errors", async () => {
const mockError = new Error("AWS Bedrock error")
const mockSend = jest.fn().mockRejectedValue(mockError)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
it('should handle invalid response format', async () => {
const mockResponse = {
output: new TextEncoder().encode('invalid json')
};
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Bedrock completion error: AWS Bedrock error",
)
})
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
it("should handle invalid response format", async () => {
const mockResponse = {
output: new TextEncoder().encode("invalid json"),
}
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
it('should handle empty response', async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({}))
};
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
it("should handle empty response", async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({})),
}
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
it('should handle cross-region inference', async () => {
handler = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
awsAccessKey: 'test-access-key',
awsSecretKey: 'test-secret-key',
awsRegion: 'us-east-1',
awsUseCrossRegionInference: true
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({
content: 'Test response'
}))
};
it("should handle cross-region inference", async () => {
handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
awsUseCrossRegionInference: true,
})
const mockSend = jest.fn().mockResolvedValue(mockResponse);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
const mockResponse = {
output: new TextEncoder().encode(
JSON.stringify({
content: "Test response",
}),
),
}
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
})
}));
});
});
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
describe('getModel', () => {
it('should return correct model info in test environment', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value
expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
}),
}),
)
})
})
it('should return test model info for invalid model in test environment', () => {
const invalidHandler = new AwsBedrockHandler({
apiModelId: 'invalid-model',
awsAccessKey: 'test-access-key',
awsSecretKey: 'test-secret-key',
awsRegion: 'us-east-1'
});
const modelInfo = invalidHandler.getModel();
expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed
expect(modelInfo.info.maxTokens).toBe(5000);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
});
});
describe("getModel", () => {
it("should return correct model info in test environment", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
})
it("should return test model info for invalid model in test environment", () => {
const invalidHandler = new AwsBedrockHandler({
apiModelId: "invalid-model",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})
const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
expect(modelInfo.info.maxTokens).toBe(5000)
expect(modelInfo.info.contextWindow).toBe(128_000)
})
})
})

View File

@@ -1,203 +1,217 @@
import { DeepSeekHandler } from '../deepseek';
import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import { DeepSeekHandler } from "../deepseek"
import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response', refusal: null },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
// Return async iterator for streaming
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
const mockCreate = jest.fn()
jest.mock("openai", () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response", refusal: null },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
}
describe('DeepSeekHandler', () => {
let handler: DeepSeekHandler;
let mockOptions: ApiHandlerOptions;
// Return async iterator for streaming
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
}),
},
},
})),
}
})
beforeEach(() => {
mockOptions = {
deepSeekApiKey: 'test-api-key',
deepSeekModelId: 'deepseek-chat',
deepSeekBaseUrl: 'https://api.deepseek.com/v1'
};
handler = new DeepSeekHandler(mockOptions);
mockCreate.mockClear();
});
describe("DeepSeekHandler", () => {
let handler: DeepSeekHandler
let mockOptions: ApiHandlerOptions
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(DeepSeekHandler);
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId);
});
beforeEach(() => {
mockOptions = {
deepSeekApiKey: "test-api-key",
deepSeekModelId: "deepseek-chat",
deepSeekBaseUrl: "https://api.deepseek.com/v1",
}
handler = new DeepSeekHandler(mockOptions)
mockCreate.mockClear()
})
it('should throw error if API key is missing', () => {
expect(() => {
new DeepSeekHandler({
...mockOptions,
deepSeekApiKey: undefined
});
}).toThrow('DeepSeek API key is required');
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(DeepSeekHandler)
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId)
})
it('should use default model ID if not provided', () => {
const handlerWithoutModel = new DeepSeekHandler({
...mockOptions,
deepSeekModelId: undefined
});
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId);
});
it("should throw error if API key is missing", () => {
expect(() => {
new DeepSeekHandler({
...mockOptions,
deepSeekApiKey: undefined,
})
}).toThrow("DeepSeek API key is required")
})
it('should use default base URL if not provided', () => {
const handlerWithoutBaseUrl = new DeepSeekHandler({
...mockOptions,
deepSeekBaseUrl: undefined
});
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler);
// The base URL is passed to OpenAI client internally
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
baseURL: 'https://api.deepseek.com/v1'
}));
});
it("should use default model ID if not provided", () => {
const handlerWithoutModel = new DeepSeekHandler({
...mockOptions,
deepSeekModelId: undefined,
})
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId)
})
it('should use custom base URL if provided', () => {
const customBaseUrl = 'https://custom.deepseek.com/v1';
const handlerWithCustomUrl = new DeepSeekHandler({
...mockOptions,
deepSeekBaseUrl: customBaseUrl
});
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler);
// The custom base URL is passed to OpenAI client
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
baseURL: customBaseUrl
}));
});
it("should use default base URL if not provided", () => {
const handlerWithoutBaseUrl = new DeepSeekHandler({
...mockOptions,
deepSeekBaseUrl: undefined,
})
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler)
// The base URL is passed to OpenAI client internally
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://api.deepseek.com/v1",
}),
)
})
it('should set includeMaxTokens to true', () => {
// Create a new handler and verify OpenAI client was called with includeMaxTokens
new DeepSeekHandler(mockOptions);
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
apiKey: mockOptions.deepSeekApiKey
}));
});
});
it("should use custom base URL if provided", () => {
const customBaseUrl = "https://custom.deepseek.com/v1"
const handlerWithCustomUrl = new DeepSeekHandler({
...mockOptions,
deepSeekBaseUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler)
// The custom base URL is passed to OpenAI client
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: customBaseUrl,
}),
)
})
describe('getModel', () => {
it('should return model info for valid model ID', () => {
const model = handler.getModel();
expect(model.id).toBe(mockOptions.deepSeekModelId);
expect(model.info).toBeDefined();
expect(model.info.maxTokens).toBe(8192);
expect(model.info.contextWindow).toBe(64_000);
expect(model.info.supportsImages).toBe(false);
expect(model.info.supportsPromptCache).toBe(false);
});
it("should set includeMaxTokens to true", () => {
// Create a new handler and verify OpenAI client was called with includeMaxTokens
new DeepSeekHandler(mockOptions)
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: mockOptions.deepSeekApiKey,
}),
)
})
})
it('should return provided model ID with default model info if model does not exist', () => {
const handlerWithInvalidModel = new DeepSeekHandler({
...mockOptions,
deepSeekModelId: 'invalid-model'
});
const model = handlerWithInvalidModel.getModel();
expect(model.id).toBe('invalid-model'); // Returns provided ID
expect(model.info).toBeDefined();
expect(model.info).toBe(handler.getModel().info); // But uses default model info
});
describe("getModel", () => {
it("should return model info for valid model ID", () => {
const model = handler.getModel()
expect(model.id).toBe(mockOptions.deepSeekModelId)
expect(model.info).toBeDefined()
expect(model.info.maxTokens).toBe(8192)
expect(model.info.contextWindow).toBe(64_000)
expect(model.info.supportsImages).toBe(false)
expect(model.info.supportsPromptCache).toBe(false)
})
it('should return default model if no model ID is provided', () => {
const handlerWithoutModel = new DeepSeekHandler({
...mockOptions,
deepSeekModelId: undefined
});
const model = handlerWithoutModel.getModel();
expect(model.id).toBe(deepSeekDefaultModelId);
expect(model.info).toBeDefined();
});
});
it("should return provided model ID with default model info if model does not exist", () => {
const handlerWithInvalidModel = new DeepSeekHandler({
...mockOptions,
deepSeekModelId: "invalid-model",
})
const model = handlerWithInvalidModel.getModel()
expect(model.id).toBe("invalid-model") // Returns provided ID
expect(model.info).toBeDefined()
expect(model.info).toBe(handler.getModel().info) // But uses default model info
})
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [{
type: 'text' as const,
text: 'Hello!'
}]
}
];
it("should return default model if no model ID is provided", () => {
const handlerWithoutModel = new DeepSeekHandler({
...mockOptions,
deepSeekModelId: undefined,
})
const model = handlerWithoutModel.getModel()
expect(model.id).toBe(deepSeekDefaultModelId)
expect(model.info).toBeDefined()
})
})
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text" as const,
text: "Hello!",
},
],
},
]
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
it('should include usage information', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})
const usageChunks = chunks.filter(chunk => chunk.type === 'usage');
expect(usageChunks.length).toBeGreaterThan(0);
expect(usageChunks[0].inputTokens).toBe(10);
expect(usageChunks[0].outputTokens).toBe(5);
});
});
});
it("should include usage information", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
expect(usageChunks.length).toBeGreaterThan(0)
expect(usageChunks[0].inputTokens).toBe(10)
expect(usageChunks[0].outputTokens).toBe(5)
})
})
})

View File

@@ -1,212 +1,210 @@
import { GeminiHandler } from '../gemini';
import { Anthropic } from '@anthropic-ai/sdk';
import { GoogleGenerativeAI } from '@google/generative-ai';
import { GeminiHandler } from "../gemini"
import { Anthropic } from "@anthropic-ai/sdk"
import { GoogleGenerativeAI } from "@google/generative-ai"
// Mock the Google Generative AI SDK
jest.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: jest.fn().mockReturnValue({
generateContentStream: jest.fn(),
generateContent: jest.fn().mockResolvedValue({
response: {
text: () => 'Test response'
}
})
})
}))
}));
jest.mock("@google/generative-ai", () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: jest.fn().mockReturnValue({
generateContentStream: jest.fn(),
generateContent: jest.fn().mockResolvedValue({
response: {
text: () => "Test response",
},
}),
}),
})),
}))
describe('GeminiHandler', () => {
let handler: GeminiHandler;
describe("GeminiHandler", () => {
let handler: GeminiHandler
beforeEach(() => {
handler = new GeminiHandler({
apiKey: 'test-key',
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
geminiApiKey: 'test-key'
});
});
beforeEach(() => {
handler = new GeminiHandler({
apiKey: "test-key",
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
geminiApiKey: "test-key",
})
})
describe('constructor', () => {
it('should initialize with provided config', () => {
expect(handler['options'].geminiApiKey).toBe('test-key');
expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219');
});
describe("constructor", () => {
it("should initialize with provided config", () => {
expect(handler["options"].geminiApiKey).toBe("test-key")
expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
})
it('should throw if API key is missing', () => {
expect(() => {
new GeminiHandler({
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
geminiApiKey: ''
});
}).toThrow('API key is required for Google Gemini');
});
});
it("should throw if API key is missing", () => {
expect(() => {
new GeminiHandler({
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
geminiApiKey: "",
})
}).toThrow("API key is required for Google Gemini")
})
})
describe('createMessage', () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
}
];
describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
{
role: "assistant",
content: "Hi there!",
},
]
const systemPrompt = 'You are a helpful assistant';
const systemPrompt = "You are a helpful assistant"
it('should handle text messages correctly', async () => {
// Mock the stream response
const mockStream = {
stream: [
{ text: () => 'Hello' },
{ text: () => ' world!' }
],
response: {
usageMetadata: {
promptTokenCount: 10,
candidatesTokenCount: 5
}
}
};
it("should handle text messages correctly", async () => {
// Mock the stream response
const mockStream = {
stream: [{ text: () => "Hello" }, { text: () => " world!" }],
response: {
usageMetadata: {
promptTokenCount: 10,
candidatesTokenCount: 5,
},
},
}
// Setup the mock implementation
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream);
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContentStream: mockGenerateContentStream
});
// Setup the mock implementation
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContentStream: mockGenerateContentStream,
})
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
// Should have 3 chunks: 'Hello', ' world!', and usage info
expect(chunks.length).toBe(3);
expect(chunks[0]).toEqual({
type: 'text',
text: 'Hello'
});
expect(chunks[1]).toEqual({
type: 'text',
text: ' world!'
});
expect(chunks[2]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 5
});
for await (const chunk of stream) {
chunks.push(chunk)
}
// Verify the model configuration
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: 'gemini-2.0-flash-thinking-exp-1219',
systemInstruction: systemPrompt
});
// Should have 3 chunks: 'Hello', ' world!', and usage info
expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({
type: "text",
text: "Hello",
})
expect(chunks[1]).toEqual({
type: "text",
text: " world!",
})
expect(chunks[2]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
})
// Verify generation config
expect(mockGenerateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
generationConfig: {
temperature: 0
}
})
);
});
// Verify the model configuration
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: "gemini-2.0-flash-thinking-exp-1219",
systemInstruction: systemPrompt,
})
it('should handle API errors', async () => {
const mockError = new Error('Gemini API error');
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError);
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContentStream: mockGenerateContentStream
});
// Verify generation config
expect(mockGenerateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
generationConfig: {
temperature: 0,
},
}),
)
})
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
it("should handle API errors", async () => {
const mockError = new Error("Gemini API error")
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContentStream: mockGenerateContentStream,
})
const stream = handler.createMessage(systemPrompt, mockMessages);
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow('Gemini API error');
});
});
const stream = handler.createMessage(systemPrompt, mockMessages)
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => 'Test response'
}
});
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow("Gemini API error")
})
})
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: 'gemini-2.0-flash-thinking-exp-1219'
});
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
generationConfig: {
temperature: 0
}
});
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => "Test response",
},
})
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent,
})
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
it('should handle API errors', async () => {
const mockError = new Error('Gemini API error');
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: "gemini-2.0-flash-thinking-exp-1219",
})
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
generationConfig: {
temperature: 0,
},
})
})
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Gemini completion error: Gemini API error');
});
it("should handle API errors", async () => {
const mockError = new Error("Gemini API error")
const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent,
})
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
it('should handle empty response', async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => ''
}
});
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent
});
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Gemini completion error: Gemini API error",
)
})
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
it("should handle empty response", async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
text: () => "",
},
})
const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent,
})
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219');
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(8192);
expect(modelInfo.info.contextWindow).toBe(32_767);
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
it('should return default model if invalid model specified', () => {
const invalidHandler = new GeminiHandler({
apiModelId: 'invalid-model',
geminiApiKey: 'test-key'
});
const modelInfo = invalidHandler.getModel();
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); // Default model
});
});
});
describe("getModel", () => {
it("should return correct model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(32_767)
})
it("should return default model if invalid model specified", () => {
const invalidHandler = new GeminiHandler({
apiModelId: "invalid-model",
geminiApiKey: "test-key",
})
const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model
})
})
})

View File

@@ -1,226 +1,238 @@
import { GlamaHandler } from '../glama';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import axios from 'axios';
import { GlamaHandler } from "../glama"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"
// Mock OpenAI client
const mockCreate = jest.fn();
const mockWithResponse = jest.fn();
const mockCreate = jest.fn()
const mockWithResponse = jest.fn()
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: (...args: any[]) => {
const stream = {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
jest.mock("openai", () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: (...args: any[]) => {
const stream = {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
const result = mockCreate(...args);
if (args[0].stream) {
mockWithResponse.mockReturnValue(Promise.resolve({
data: stream,
response: {
headers: {
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
}
}
}));
result.withResponse = mockWithResponse;
}
return result;
}
}
}
}))
};
});
const result = mockCreate(...args)
if (args[0].stream) {
mockWithResponse.mockReturnValue(
Promise.resolve({
data: stream,
response: {
headers: {
get: (name: string) =>
name === "x-completion-request-id" ? "test-request-id" : null,
},
},
}),
)
result.withResponse = mockWithResponse
}
return result
},
},
},
})),
}
})
describe('GlamaHandler', () => {
let handler: GlamaHandler;
let mockOptions: ApiHandlerOptions;
describe("GlamaHandler", () => {
let handler: GlamaHandler
let mockOptions: ApiHandlerOptions
beforeEach(() => {
mockOptions = {
apiModelId: 'anthropic/claude-3-5-sonnet',
glamaModelId: 'anthropic/claude-3-5-sonnet',
glamaApiKey: 'test-api-key'
};
handler = new GlamaHandler(mockOptions);
mockCreate.mockClear();
mockWithResponse.mockClear();
beforeEach(() => {
mockOptions = {
apiModelId: "anthropic/claude-3-5-sonnet",
glamaModelId: "anthropic/claude-3-5-sonnet",
glamaApiKey: "test-api-key",
}
handler = new GlamaHandler(mockOptions)
mockCreate.mockClear()
mockWithResponse.mockClear()
// Default mock implementation for non-streaming responses
mockCreate.mockResolvedValue({
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
});
});
// Default mock implementation for non-streaming responses
mockCreate.mockResolvedValue({
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response" },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
})
})
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(GlamaHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(GlamaHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
})
})
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]
it('should handle streaming responses', async () => {
// Mock axios for token usage request
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
data: {
tokenUsage: {
promptTokens: 10,
completionTokens: 5,
cacheCreationInputTokens: 0,
cacheReadInputTokens: 0
},
totalCostUsd: "0.00"
}
});
it("should handle streaming responses", async () => {
// Mock axios for token usage request
const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({
data: {
tokenUsage: {
promptTokens: 10,
completionTokens: 5,
cacheCreationInputTokens: 0,
cacheReadInputTokens: 0,
},
totalCostUsd: "0.00",
},
})
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
expect(chunks.length).toBe(2); // Text chunk and usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: 'Test response'
});
expect(chunks[1]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 5,
cacheWriteTokens: 0,
cacheReadTokens: 0,
totalCost: 0
});
expect(chunks.length).toBe(2) // Text chunk and usage chunk
expect(chunks[0]).toEqual({
type: "text",
text: "Test response",
})
expect(chunks[1]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
cacheWriteTokens: 0,
cacheReadTokens: 0,
totalCost: 0,
})
mockAxios.mockRestore();
});
mockAxios.mockRestore()
})
it('should handle API errors', async () => {
mockCreate.mockImplementationOnce(() => {
throw new Error('API Error');
});
it("should handle API errors", async () => {
mockCreate.mockImplementationOnce(() => {
throw new Error("API Error")
})
const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
try {
for await (const chunk of stream) {
chunks.push(chunk);
}
fail('Expected error to be thrown');
} catch (error) {
expect(error).toBeInstanceOf(Error);
expect(error.message).toBe('API Error');
}
});
});
try {
for await (const chunk of stream) {
chunks.push(chunk)
}
fail("Expected error to be thrown")
} catch (error) {
expect(error).toBeInstanceOf(Error)
expect(error.message).toBe("API Error")
}
})
})
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
max_tokens: 8192
}));
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
temperature: 0,
max_tokens: 8192,
}),
)
})
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Glama completion error: API Error');
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error")
})
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "" } }],
})
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
it('should not set max_tokens for non-Anthropic models', async () => {
// Reset mock to clear any previous calls
mockCreate.mockClear();
const nonAnthropicOptions = {
apiModelId: 'openai/gpt-4',
glamaModelId: 'openai/gpt-4',
glamaApiKey: 'test-key',
glamaModelInfo: {
maxTokens: 4096,
contextWindow: 8192,
supportsImages: true,
supportsPromptCache: false
}
};
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
it("should not set max_tokens for non-Anthropic models", async () => {
// Reset mock to clear any previous calls
mockCreate.mockClear()
await nonAnthropicHandler.completePrompt('Test prompt');
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: 'openai/gpt-4',
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
}));
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
});
});
const nonAnthropicOptions = {
apiModelId: "openai/gpt-4",
glamaModelId: "openai/gpt-4",
glamaApiKey: "test-key",
glamaModelInfo: {
maxTokens: 4096,
contextWindow: 8192,
supportsImages: true,
supportsPromptCache: false,
},
}
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions)
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.apiModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(8192);
expect(modelInfo.info.contextWindow).toBe(200_000);
});
});
});
await nonAnthropicHandler.completePrompt("Test prompt")
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "openai/gpt-4",
messages: [{ role: "user", content: "Test prompt" }],
temperature: 0,
}),
)
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
})
})
describe("getModel", () => {
it("should return model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.apiModelId)
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(200_000)
})
})
})

View File

@@ -1,160 +1,167 @@
import { LmStudioHandler } from '../lmstudio';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import { LmStudioHandler } from "../lmstudio"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
const mockCreate = jest.fn()
jest.mock("openai", () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response" },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
}
describe('LmStudioHandler', () => {
let handler: LmStudioHandler;
let mockOptions: ApiHandlerOptions;
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
}),
},
},
})),
}
})
beforeEach(() => {
mockOptions = {
apiModelId: 'local-model',
lmStudioModelId: 'local-model',
lmStudioBaseUrl: 'http://localhost:1234/v1'
};
handler = new LmStudioHandler(mockOptions);
mockCreate.mockClear();
});
describe("LmStudioHandler", () => {
let handler: LmStudioHandler
let mockOptions: ApiHandlerOptions
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(LmStudioHandler);
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
});
beforeEach(() => {
mockOptions = {
apiModelId: "local-model",
lmStudioModelId: "local-model",
lmStudioBaseUrl: "http://localhost:1234/v1",
}
handler = new LmStudioHandler(mockOptions)
mockCreate.mockClear()
})
it('should use default base URL if not provided', () => {
const handlerWithoutUrl = new LmStudioHandler({
apiModelId: 'local-model',
lmStudioModelId: 'local-model'
});
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
});
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(LmStudioHandler)
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId)
})
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it("should use default base URL if not provided", () => {
const handlerWithoutUrl = new LmStudioHandler({
apiModelId: "local-model",
lmStudioModelId: "local-model",
})
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler)
})
})
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})
const stream = handler.createMessage(systemPrompt, messages);
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
});
});
const stream = handler.createMessage(systemPrompt, messages)
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.lmStudioModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
stream: false
});
});
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong")
})
})
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.lmStudioModelId,
messages: [{ role: "user", content: "Test prompt" }],
temperature: 0,
stream: false,
})
})
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Please check the LM Studio developer logs to debug what went wrong",
)
})
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(-1);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
});
});
it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "" } }],
})
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
describe("getModel", () => {
it("should return model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(-1)
expect(modelInfo.info.contextWindow).toBe(128_000)
})
})
})

View File

@@ -1,160 +1,165 @@
import { OllamaHandler } from '../ollama';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import { OllamaHandler } from "../ollama"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
const mockCreate = jest.fn()
jest.mock("openai", () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response" },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
}
describe('OllamaHandler', () => {
let handler: OllamaHandler;
let mockOptions: ApiHandlerOptions;
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
}),
},
},
})),
}
})
beforeEach(() => {
mockOptions = {
apiModelId: 'llama2',
ollamaModelId: 'llama2',
ollamaBaseUrl: 'http://localhost:11434/v1'
};
handler = new OllamaHandler(mockOptions);
mockCreate.mockClear();
});
describe("OllamaHandler", () => {
let handler: OllamaHandler
let mockOptions: ApiHandlerOptions
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OllamaHandler);
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
});
beforeEach(() => {
mockOptions = {
apiModelId: "llama2",
ollamaModelId: "llama2",
ollamaBaseUrl: "http://localhost:11434/v1",
}
handler = new OllamaHandler(mockOptions)
mockCreate.mockClear()
})
it('should use default base URL if not provided', () => {
const handlerWithoutUrl = new OllamaHandler({
apiModelId: 'llama2',
ollamaModelId: 'llama2'
});
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
});
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(OllamaHandler)
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId)
})
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
it("should use default base URL if not provided", () => {
const handlerWithoutUrl = new OllamaHandler({
apiModelId: "llama2",
ollamaModelId: "llama2",
})
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler)
})
})
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})
const stream = handler.createMessage(systemPrompt, messages);
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('API Error');
});
});
const stream = handler.createMessage(systemPrompt, messages)
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.ollamaModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0,
stream: false
});
});
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow("API Error")
})
})
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Ollama completion error: API Error');
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.ollamaModelId,
messages: [{ role: "user", content: "Test prompt" }],
temperature: 0,
stream: false,
})
})
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error")
})
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(-1);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
});
});
it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "" } }],
})
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
describe("getModel", () => {
it("should return model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.ollamaModelId)
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(-1)
expect(modelInfo.info.contextWindow).toBe(128_000)
})
})
})

View File

@@ -1,319 +1,326 @@
import { OpenAiNativeHandler } from '../openai-native';
import { ApiHandlerOptions } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import { OpenAiNativeHandler } from "../openai-native"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response' },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
const mockCreate = jest.fn()
jest.mock("openai", () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response" },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
}
describe('OpenAiNativeHandler', () => {
let handler: OpenAiNativeHandler;
let mockOptions: ApiHandlerOptions;
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello!'
}
];
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
}),
},
},
})),
}
})
beforeEach(() => {
mockOptions = {
apiModelId: 'gpt-4o',
openAiNativeApiKey: 'test-api-key'
};
handler = new OpenAiNativeHandler(mockOptions);
mockCreate.mockClear();
});
describe("OpenAiNativeHandler", () => {
let handler: OpenAiNativeHandler
let mockOptions: ApiHandlerOptions
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
});
beforeEach(() => {
mockOptions = {
apiModelId: "gpt-4o",
openAiNativeApiKey: "test-api-key",
}
handler = new OpenAiNativeHandler(mockOptions)
mockCreate.mockClear()
})
it('should initialize with empty API key', () => {
const handlerWithoutKey = new OpenAiNativeHandler({
apiModelId: 'gpt-4o',
openAiNativeApiKey: ''
});
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
});
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(OpenAiNativeHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
})
describe('createMessage', () => {
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
it("should initialize with empty API key", () => {
const handlerWithoutKey = new OpenAiNativeHandler({
apiModelId: "gpt-4o",
openAiNativeApiKey: "",
})
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler)
})
})
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
describe("createMessage", () => {
it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
const stream = handler.createMessage(systemPrompt, messages);
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('API Error');
});
expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})
it('should handle missing content in response for o1 model', async () => {
// Use o1 model which supports developer role
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: 'o1'
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
const stream = handler.createMessage(systemPrompt, messages)
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow("API Error")
})
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: null } }],
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
});
it("should handle missing content in response for o1 model", async () => {
// Use o1 model which supports developer role
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "o1",
})
const generator = handler.createMessage(systemPrompt, messages);
const results = [];
for await (const result of generator) {
results.push(result);
}
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: null } }],
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
expect(results).toEqual([
{ type: 'text', text: '' },
{ type: 'usage', inputTokens: 0, outputTokens: 0 }
]);
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
// Verify developer role is used for system prompt with o1 model
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1',
messages: [
{ role: 'developer', content: systemPrompt },
{ role: 'user', content: 'Hello!' }
]
});
});
});
expect(results).toEqual([
{ type: "text", text: "" },
{ type: "usage", inputTokens: 0, outputTokens: 0 },
])
describe('streaming models', () => {
beforeEach(() => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: 'gpt-4o',
});
});
// Verify developer role is used for system prompt with o1 model
expect(mockCreate).toHaveBeenCalledWith({
model: "o1",
messages: [
{ role: "developer", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
})
})
})
it('should handle streaming response', async () => {
const mockStream = [
{ choices: [{ delta: { content: 'Hello' } }], usage: null },
{ choices: [{ delta: { content: ' there' } }], usage: null },
{ choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
];
describe("streaming models", () => {
beforeEach(() => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "gpt-4o",
})
})
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk;
}
})()
);
it("should handle streaming response", async () => {
const mockStream = [
{ choices: [{ delta: { content: "Hello" } }], usage: null },
{ choices: [{ delta: { content: " there" } }], usage: null },
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]
const generator = handler.createMessage(systemPrompt, messages);
const results = [];
for await (const result of generator) {
results.push(result);
}
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk
}
})(),
)
expect(results).toEqual([
{ type: 'text', text: 'Hello' },
{ type: 'text', text: ' there' },
{ type: 'text', text: '!' },
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
]);
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(mockCreate).toHaveBeenCalledWith({
model: 'gpt-4o',
temperature: 0,
messages: [
{ role: 'system', content: systemPrompt },
{ role: 'user', content: 'Hello!' },
],
stream: true,
stream_options: { include_usage: true },
});
});
expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " there" },
{ type: "text", text: "!" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
it('should handle empty delta content', async () => {
const mockStream = [
{ choices: [{ delta: {} }], usage: null },
{ choices: [{ delta: { content: null } }], usage: null },
{ choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
];
expect(mockCreate).toHaveBeenCalledWith({
model: "gpt-4o",
temperature: 0,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
stream: true,
stream_options: { include_usage: true },
})
})
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk;
}
})()
);
it("should handle empty delta content", async () => {
const mockStream = [
{ choices: [{ delta: {} }], usage: null },
{ choices: [{ delta: { content: null } }], usage: null },
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]
const generator = handler.createMessage(systemPrompt, messages);
const results = [];
for await (const result of generator) {
results.push(result);
}
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk
}
})(),
)
expect(results).toEqual([
{ type: 'text', text: 'Hello' },
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
]);
});
});
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
describe('completePrompt', () => {
it('should complete prompt successfully with gpt-4o model', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'gpt-4o',
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
});
});
expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})
})
it('should complete prompt successfully with o1 model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1',
openAiNativeApiKey: 'test-api-key'
});
describe("completePrompt", () => {
it("should complete prompt successfully with gpt-4o model", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: "gpt-4o",
messages: [{ role: "user", content: "Test prompt" }],
temperature: 0,
})
})
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it("should complete prompt successfully with o1 model", async () => {
handler = new OpenAiNativeHandler({
apiModelId: "o1",
openAiNativeApiKey: "test-api-key",
})
it('should complete prompt successfully with o1-preview model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1-preview',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: "o1",
messages: [{ role: "user", content: "Test prompt" }],
})
})
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-preview',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it("should complete prompt successfully with o1-preview model", async () => {
handler = new OpenAiNativeHandler({
apiModelId: "o1-preview",
openAiNativeApiKey: "test-api-key",
})
it('should complete prompt successfully with o1-mini model', async () => {
handler = new OpenAiNativeHandler({
apiModelId: 'o1-mini',
openAiNativeApiKey: 'test-api-key'
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: "o1-preview",
messages: [{ role: "user", content: "Test prompt" }],
})
})
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-mini',
messages: [{ role: 'user', content: 'Test prompt' }]
});
});
it("should complete prompt successfully with o1-mini model", async () => {
handler = new OpenAiNativeHandler({
apiModelId: "o1-mini",
openAiNativeApiKey: "test-api-key",
})
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('OpenAI Native completion error: API Error');
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: "o1-mini",
messages: [{ role: "user", content: "Test prompt" }],
})
})
it('should handle empty response', async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }]
});
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"OpenAI Native completion error: API Error",
)
})
describe('getModel', () => {
it('should return model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe(mockOptions.apiModelId);
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(4096);
expect(modelInfo.info.contextWindow).toBe(128_000);
});
it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "" } }],
})
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
it('should handle undefined model ID', () => {
const handlerWithoutModel = new OpenAiNativeHandler({
openAiNativeApiKey: 'test-api-key'
});
const modelInfo = handlerWithoutModel.getModel();
expect(modelInfo.id).toBe('gpt-4o'); // Default model
expect(modelInfo.info).toBeDefined();
});
});
});
describe("getModel", () => {
it("should return model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.apiModelId)
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(4096)
expect(modelInfo.info.contextWindow).toBe(128_000)
})
it("should handle undefined model ID", () => {
const handlerWithoutModel = new OpenAiNativeHandler({
openAiNativeApiKey: "test-api-key",
})
const modelInfo = handlerWithoutModel.getModel()
expect(modelInfo.id).toBe("gpt-4o") // Default model
expect(modelInfo.info).toBeDefined()
})
})
})

View File

@@ -1,224 +1,233 @@
import { OpenAiHandler } from '../openai';
import { ApiHandlerOptions } from '../../../shared/api';
import { ApiStream } from '../../transform/stream';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { ApiStream } from "../../transform/stream"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client
const mockCreate = jest.fn();
jest.mock('openai', () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
choices: [{
message: { role: 'assistant', content: 'Test response', refusal: null },
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{
delta: { content: 'Test response' },
index: 0
}],
usage: null
};
yield {
choices: [{
delta: {},
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
}
};
})
}
}
}))
};
});
const mockCreate = jest.fn()
jest.mock("openai", () => {
return {
__esModule: true,
default: jest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response", refusal: null },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
}
describe('OpenAiHandler', () => {
let handler: OpenAiHandler;
let mockOptions: ApiHandlerOptions;
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
}),
},
},
})),
}
})
beforeEach(() => {
mockOptions = {
openAiApiKey: 'test-api-key',
openAiModelId: 'gpt-4',
openAiBaseUrl: 'https://api.openai.com/v1'
};
handler = new OpenAiHandler(mockOptions);
mockCreate.mockClear();
});
describe("OpenAiHandler", () => {
let handler: OpenAiHandler
let mockOptions: ApiHandlerOptions
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeInstanceOf(OpenAiHandler);
expect(handler.getModel().id).toBe(mockOptions.openAiModelId);
});
beforeEach(() => {
mockOptions = {
openAiApiKey: "test-api-key",
openAiModelId: "gpt-4",
openAiBaseUrl: "https://api.openai.com/v1",
}
handler = new OpenAiHandler(mockOptions)
mockCreate.mockClear()
})
it('should use custom base URL if provided', () => {
const customBaseUrl = 'https://custom.openai.com/v1';
const handlerWithCustomUrl = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: customBaseUrl
});
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler);
});
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(OpenAiHandler)
expect(handler.getModel().id).toBe(mockOptions.openAiModelId)
})
describe('createMessage', () => {
const systemPrompt = 'You are a helpful assistant.';
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [{
type: 'text' as const,
text: 'Hello!'
}]
}
];
it("should use custom base URL if provided", () => {
const customBaseUrl = "https://custom.openai.com/v1"
const handlerWithCustomUrl = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
})
})
it('should handle non-streaming mode', async () => {
const handler = new OpenAiHandler({
...mockOptions,
openAiStreamingEnabled: false
});
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text" as const,
text: "Hello!",
},
],
},
]
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
it("should handle non-streaming mode", async () => {
const handler = new OpenAiHandler({
...mockOptions,
openAiStreamingEnabled: false,
})
expect(chunks.length).toBeGreaterThan(0);
const textChunk = chunks.find(chunk => chunk.type === 'text');
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
expect(textChunk).toBeDefined();
expect(textChunk?.text).toBe('Test response');
expect(usageChunk).toBeDefined();
expect(usageChunk?.inputTokens).toBe(10);
expect(usageChunk?.outputTokens).toBe(5);
});
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
it('should handle streaming responses', async () => {
const stream = handler.createMessage(systemPrompt, messages);
const chunks: any[] = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0)
const textChunk = chunks.find((chunk) => chunk.type === "text")
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
expect(chunks.length).toBeGreaterThan(0);
const textChunks = chunks.filter(chunk => chunk.type === 'text');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].text).toBe('Test response');
});
});
expect(textChunk).toBeDefined()
expect(textChunk?.text).toBe("Test response")
expect(usageChunk).toBeDefined()
expect(usageChunk?.inputTokens).toBe(10)
expect(usageChunk?.outputTokens).toBe(5)
})
describe('error handling', () => {
const testMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [{
type: 'text' as const,
text: 'Hello'
}]
}
];
it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})
})
const stream = handler.createMessage('system prompt', testMessages);
describe("error handling", () => {
const testMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text" as const,
text: "Hello",
},
],
},
]
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('API Error');
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
it('should handle rate limiting', async () => {
const rateLimitError = new Error('Rate limit exceeded');
rateLimitError.name = 'Error';
(rateLimitError as any).status = 429;
mockCreate.mockRejectedValueOnce(rateLimitError);
const stream = handler.createMessage("system prompt", testMessages)
const stream = handler.createMessage('system prompt', testMessages);
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow("API Error")
})
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow('Rate limit exceeded');
});
});
it("should handle rate limiting", async () => {
const rateLimitError = new Error("Rate limit exceeded")
rateLimitError.name = "Error"
;(rateLimitError as any).status = 429
mockCreate.mockRejectedValueOnce(rateLimitError)
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId,
messages: [{ role: 'user', content: 'Test prompt' }],
temperature: 0
});
});
const stream = handler.createMessage("system prompt", testMessages)
it('should handle API errors', async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error'));
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('OpenAI completion error: API Error');
});
await expect(async () => {
for await (const chunk of stream) {
// Should not reach here
}
}).rejects.toThrow("Rate limit exceeded")
})
})
it('should handle empty response', async () => {
mockCreate.mockImplementationOnce(() => ({
choices: [{ message: { content: '' } }]
}));
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
temperature: 0,
})
})
describe('getModel', () => {
it('should return model info with sane defaults', () => {
const model = handler.getModel();
expect(model.id).toBe(mockOptions.openAiModelId);
expect(model.info).toBeDefined();
expect(model.info.contextWindow).toBe(128_000);
expect(model.info.supportsImages).toBe(true);
});
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error")
})
it('should handle undefined model ID', () => {
const handlerWithoutModel = new OpenAiHandler({
...mockOptions,
openAiModelId: undefined
});
const model = handlerWithoutModel.getModel();
expect(model.id).toBe('');
expect(model.info).toBeDefined();
});
});
});
it("should handle empty response", async () => {
mockCreate.mockImplementationOnce(() => ({
choices: [{ message: { content: "" } }],
}))
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
describe("getModel", () => {
it("should return model info with sane defaults", () => {
const model = handler.getModel()
expect(model.id).toBe(mockOptions.openAiModelId)
expect(model.info).toBeDefined()
expect(model.info.contextWindow).toBe(128_000)
expect(model.info.supportsImages).toBe(true)
})
it("should handle undefined model ID", () => {
const handlerWithoutModel = new OpenAiHandler({
...mockOptions,
openAiModelId: undefined,
})
const model = handlerWithoutModel.getModel()
expect(model.id).toBe("")
expect(model.info).toBeDefined()
})
})
})

View File

@@ -1,283 +1,297 @@
import { OpenRouterHandler } from '../openrouter'
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
import OpenAI from 'openai'
import axios from 'axios'
import { Anthropic } from '@anthropic-ai/sdk'
import { OpenRouterHandler } from "../openrouter"
import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
import OpenAI from "openai"
import axios from "axios"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock dependencies
jest.mock('openai')
jest.mock('axios')
jest.mock('delay', () => jest.fn(() => Promise.resolve()))
jest.mock("openai")
jest.mock("axios")
jest.mock("delay", () => jest.fn(() => Promise.resolve()))
describe('OpenRouterHandler', () => {
const mockOptions: ApiHandlerOptions = {
openRouterApiKey: 'test-key',
openRouterModelId: 'test-model',
openRouterModelInfo: {
name: 'Test Model',
description: 'Test Description',
maxTokens: 1000,
contextWindow: 2000,
supportsPromptCache: true,
inputPrice: 0.01,
outputPrice: 0.02
} as ModelInfo
}
describe("OpenRouterHandler", () => {
const mockOptions: ApiHandlerOptions = {
openRouterApiKey: "test-key",
openRouterModelId: "test-model",
openRouterModelInfo: {
name: "Test Model",
description: "Test Description",
maxTokens: 1000,
contextWindow: 2000,
supportsPromptCache: true,
inputPrice: 0.01,
outputPrice: 0.02,
} as ModelInfo,
}
beforeEach(() => {
jest.clearAllMocks()
})
beforeEach(() => {
jest.clearAllMocks()
})
test('constructor initializes with correct options', () => {
const handler = new OpenRouterHandler(mockOptions)
expect(handler).toBeInstanceOf(OpenRouterHandler)
expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'https://openrouter.ai/api/v1',
apiKey: mockOptions.openRouterApiKey,
defaultHeaders: {
'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline',
'X-Title': 'Roo-Cline',
},
})
})
test("constructor initializes with correct options", () => {
const handler = new OpenRouterHandler(mockOptions)
expect(handler).toBeInstanceOf(OpenRouterHandler)
expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://openrouter.ai/api/v1",
apiKey: mockOptions.openRouterApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo-Cline",
},
})
})
test('getModel returns correct model info when options are provided', () => {
const handler = new OpenRouterHandler(mockOptions)
const result = handler.getModel()
expect(result).toEqual({
id: mockOptions.openRouterModelId,
info: mockOptions.openRouterModelInfo
})
})
test("getModel returns correct model info when options are provided", () => {
const handler = new OpenRouterHandler(mockOptions)
const result = handler.getModel()
test('getModel returns default model info when options are not provided', () => {
const handler = new OpenRouterHandler({})
const result = handler.getModel()
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
expect(result.info.supportsPromptCache).toBe(true)
})
expect(result).toEqual({
id: mockOptions.openRouterModelId,
info: mockOptions.openRouterModelInfo,
})
})
test('createMessage generates correct stream chunks', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'test-id',
choices: [{
delta: {
content: 'test response'
}
}]
}
}
}
test("getModel returns default model info when options are not provided", () => {
const handler = new OpenRouterHandler({})
const result = handler.getModel()
// Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta")
expect(result.info.supportsPromptCache).toBe(true)
})
// Mock axios.get for generation details
;(axios.get as jest.Mock).mockResolvedValue({
data: {
data: {
native_tokens_prompt: 10,
native_tokens_completion: 20,
total_cost: 0.001
}
}
})
test("createMessage generates correct stream chunks", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
content: "test response",
},
},
],
}
},
}
const systemPrompt = 'test system prompt'
const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }]
// Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
const generator = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of generator) {
chunks.push(chunk)
}
// Mock axios.get for generation details
;(axios.get as jest.Mock).mockResolvedValue({
data: {
data: {
native_tokens_prompt: 10,
native_tokens_completion: 20,
total_cost: 0.001,
},
},
})
// Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: 'test response'
})
expect(chunks[1]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 20,
totalCost: 0.001,
fullResponseText: 'test response'
})
const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
model: mockOptions.openRouterModelId,
temperature: 0,
messages: expect.arrayContaining([
{ role: 'system', content: systemPrompt },
{ role: 'user', content: 'test message' }
]),
stream: true
}))
})
const generator = handler.createMessage(systemPrompt, messages)
const chunks = []
test('createMessage with middle-out transform enabled', async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterUseMiddleOutTransform: true
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'test-id',
choices: [{
delta: {
content: 'test response'
}
}]
}
}
}
for await (const chunk of generator) {
chunks.push(chunk)
}
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
// Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({
type: "text",
text: "test response",
})
expect(chunks[1]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 20,
totalCost: 0.001,
fullResponseText: "test response",
})
await handler.createMessage('test', []).next()
// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: mockOptions.openRouterModelId,
temperature: 0,
messages: expect.arrayContaining([
{ role: "system", content: systemPrompt },
{ role: "user", content: "test message" },
]),
stream: true,
}),
)
})
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
transforms: ['middle-out']
}))
})
test("createMessage with middle-out transform enabled", async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterUseMiddleOutTransform: true,
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
content: "test response",
},
},
],
}
},
}
test('createMessage with Claude model adds cache control', async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterModelId: 'anthropic/claude-3.5-sonnet'
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'test-id',
choices: [{
delta: {
content: 'test response'
}
}]
}
}
}
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
await handler.createMessage("test", []).next()
const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'message 1' },
{ role: 'assistant', content: 'response 1' },
{ role: 'user', content: 'message 2' }
]
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
transforms: ["middle-out"],
}),
)
})
await handler.createMessage('test system', messages).next()
test("createMessage with Claude model adds cache control", async () => {
const handler = new OpenRouterHandler({
...mockOptions,
openRouterModelId: "anthropic/claude-3.5-sonnet",
})
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
content: "test response",
},
},
],
}
},
}
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: expect.arrayContaining([
expect.objectContaining({
cache_control: { type: 'ephemeral' }
})
])
})
])
}))
})
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
test('createMessage handles API errors', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
error: {
message: 'API Error',
code: 500
}
}
}
}
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "message 1" },
{ role: "assistant", content: "response 1" },
{ role: "user", content: "message 2" },
]
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
await handler.createMessage("test system", messages).next()
const generator = handler.createMessage('test', [])
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
})
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: "system",
content: expect.arrayContaining([
expect.objectContaining({
cache_control: { type: "ephemeral" },
}),
]),
}),
]),
}),
)
})
test('completePrompt returns correct response', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockResponse = {
choices: [{
message: {
content: 'test completion'
}
}]
}
test("createMessage handles API errors", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
error: {
message: "API Error",
code: 500,
},
}
},
}
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
const result = await handler.completePrompt('test prompt')
const generator = handler.createMessage("test", [])
await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
})
expect(result).toBe('test completion')
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openRouterModelId,
messages: [{ role: 'user', content: 'test prompt' }],
temperature: 0,
stream: false
})
})
test("completePrompt returns correct response", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockResponse = {
choices: [
{
message: {
content: "test completion",
},
},
],
}
test('completePrompt handles API errors', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockError = {
error: {
message: 'API Error',
code: 500
}
}
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
const mockCreate = jest.fn().mockResolvedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
const result = await handler.completePrompt("test prompt")
await expect(handler.completePrompt('test prompt'))
.rejects.toThrow('OpenRouter API Error 500: API Error')
})
expect(result).toBe("test completion")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openRouterModelId,
messages: [{ role: "user", content: "test prompt" }],
temperature: 0,
stream: false,
})
})
test('completePrompt handles unexpected errors', async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate }
} as any
test("completePrompt handles API errors", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockError = {
error: {
message: "API Error",
code: 500,
},
}
await expect(handler.completePrompt('test prompt'))
.rejects.toThrow('OpenRouter completion error: Unexpected error')
})
const mockCreate = jest.fn().mockResolvedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error")
})
test("completePrompt handles unexpected errors", async () => {
const handler = new OpenRouterHandler(mockOptions)
const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error"))
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any
await expect(handler.completePrompt("test prompt")).rejects.toThrow(
"OpenRouter completion error: Unexpected error",
)
})
})

View File

@@ -1,296 +1,295 @@
import { VertexHandler } from '../vertex';
import { Anthropic } from '@anthropic-ai/sdk';
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
import { VertexHandler } from "../vertex"
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
// Mock Vertex SDK
jest.mock('@anthropic-ai/vertex-sdk', () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: {
create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: 'test-completion',
content: [
{ type: 'text', text: 'Test response' }
],
role: 'assistant',
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 5
}
}
}
yield {
type: 'content_block_start',
content_block: {
type: 'text',
text: 'Test response'
}
}
}
}
})
}
}))
}));
jest.mock("@anthropic-ai/vertex-sdk", () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: {
create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
content: [{ type: "text", text: "Test response" }],
role: "assistant",
model: options.model,
usage: {
input_tokens: 10,
output_tokens: 5,
},
}
}
return {
async *[Symbol.asyncIterator]() {
yield {
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 5,
},
},
}
yield {
type: "content_block_start",
content_block: {
type: "text",
text: "Test response",
},
}
},
}
}),
},
})),
}))
describe('VertexHandler', () => {
let handler: VertexHandler;
describe("VertexHandler", () => {
let handler: VertexHandler
beforeEach(() => {
handler = new VertexHandler({
apiModelId: 'claude-3-5-sonnet-v2@20241022',
vertexProjectId: 'test-project',
vertexRegion: 'us-central1'
});
});
beforeEach(() => {
handler = new VertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})
})
describe('constructor', () => {
it('should initialize with provided config', () => {
expect(AnthropicVertex).toHaveBeenCalledWith({
projectId: 'test-project',
region: 'us-central1'
});
});
});
describe("constructor", () => {
it("should initialize with provided config", () => {
expect(AnthropicVertex).toHaveBeenCalledWith({
projectId: "test-project",
region: "us-central1",
})
})
})
describe('createMessage', () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
}
];
describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
{
role: "assistant",
content: "Hi there!",
},
]
const systemPrompt = 'You are a helpful assistant';
const systemPrompt = "You are a helpful assistant"
it('should handle streaming responses correctly', async () => {
const mockStream = [
{
type: 'message_start',
message: {
usage: {
input_tokens: 10,
output_tokens: 0
}
}
},
{
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: 'Hello'
}
},
{
type: 'content_block_delta',
delta: {
type: 'text_delta',
text: ' world!'
}
},
{
type: 'message_delta',
usage: {
output_tokens: 5
}
}
];
it("should handle streaming responses correctly", async () => {
const mockStream = [
{
type: "message_start",
message: {
usage: {
input_tokens: 10,
output_tokens: 0,
},
},
},
{
type: "content_block_start",
index: 0,
content_block: {
type: "text",
text: "Hello",
},
},
{
type: "content_block_delta",
delta: {
type: "text_delta",
text: " world!",
},
},
{
type: "message_delta",
usage: {
output_tokens: 5,
},
},
]
// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk;
}
}
};
// Setup async iterator for mock stream
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk
}
},
}
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
(handler['client'].messages as any).create = mockCreate;
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
;(handler["client"].messages as any).create = mockCreate
const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
expect(chunks.length).toBe(4);
expect(chunks[0]).toEqual({
type: 'usage',
inputTokens: 10,
outputTokens: 0
});
expect(chunks[1]).toEqual({
type: 'text',
text: 'Hello'
});
expect(chunks[2]).toEqual({
type: 'text',
text: ' world!'
});
expect(chunks[3]).toEqual({
type: 'usage',
inputTokens: 0,
outputTokens: 5
});
for await (const chunk of stream) {
chunks.push(chunk)
}
expect(mockCreate).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022',
max_tokens: 8192,
temperature: 0,
system: systemPrompt,
messages: mockMessages,
stream: true
});
});
expect(chunks.length).toBe(4)
expect(chunks[0]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 0,
})
expect(chunks[1]).toEqual({
type: "text",
text: "Hello",
})
expect(chunks[2]).toEqual({
type: "text",
text: " world!",
})
expect(chunks[3]).toEqual({
type: "usage",
inputTokens: 0,
outputTokens: 5,
})
it('should handle multiple content blocks with line breaks', async () => {
const mockStream = [
{
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: 'First line'
}
},
{
type: 'content_block_start',
index: 1,
content_block: {
type: 'text',
text: 'Second line'
}
}
];
expect(mockCreate).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
system: systemPrompt,
messages: mockMessages,
stream: true,
})
})
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk;
}
}
};
it("should handle multiple content blocks with line breaks", async () => {
const mockStream = [
{
type: "content_block_start",
index: 0,
content_block: {
type: "text",
text: "First line",
},
},
{
type: "content_block_start",
index: 1,
content_block: {
type: "text",
text: "Second line",
},
},
]
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
(handler['client'].messages as any).create = mockCreate;
const asyncIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) {
yield chunk
}
},
}
const stream = handler.createMessage(systemPrompt, mockMessages);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
;(handler["client"].messages as any).create = mockCreate
expect(chunks.length).toBe(3);
expect(chunks[0]).toEqual({
type: 'text',
text: 'First line'
});
expect(chunks[1]).toEqual({
type: 'text',
text: '\n'
});
expect(chunks[2]).toEqual({
type: 'text',
text: 'Second line'
});
});
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
it('should handle API errors', async () => {
const mockError = new Error('Vertex API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].messages as any).create = mockCreate;
for await (const chunk of stream) {
chunks.push(chunk)
}
const stream = handler.createMessage(systemPrompt, mockMessages);
expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({
type: "text",
text: "First line",
})
expect(chunks[1]).toEqual({
type: "text",
text: "\n",
})
expect(chunks[2]).toEqual({
type: "text",
text: "Second line",
})
})
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow('Vertex API error');
});
});
it("should handle API errors", async () => {
const mockError = new Error("Vertex API error")
const mockCreate = jest.fn().mockRejectedValue(mockError)
;(handler["client"].messages as any).create = mockCreate
describe('completePrompt', () => {
it('should complete prompt successfully', async () => {
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('Test response');
expect(handler['client'].messages.create).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022',
max_tokens: 8192,
temperature: 0,
messages: [{ role: 'user', content: 'Test prompt' }],
stream: false
});
});
const stream = handler.createMessage(systemPrompt, mockMessages)
it('should handle API errors', async () => {
const mockError = new Error('Vertex API error');
const mockCreate = jest.fn().mockRejectedValue(mockError);
(handler['client'].messages as any).create = mockCreate;
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow("Vertex API error")
})
})
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Vertex completion error: Vertex API error');
});
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(handler["client"].messages.create).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [{ role: "user", content: "Test prompt" }],
stream: false,
})
})
it('should handle non-text content', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'image' }]
});
(handler['client'].messages as any).create = mockCreate;
it("should handle API errors", async () => {
const mockError = new Error("Vertex API error")
const mockCreate = jest.fn().mockRejectedValue(mockError)
;(handler["client"].messages as any).create = mockCreate
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Vertex completion error: Vertex API error",
)
})
it('should handle empty response', async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'text', text: '' }]
});
(handler['client'].messages as any).create = mockCreate;
it("should handle non-text content", async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: "image" }],
})
;(handler["client"].messages as any).create = mockCreate
const result = await handler.completePrompt('Test prompt');
expect(result).toBe('');
});
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
describe('getModel', () => {
it('should return correct model info', () => {
const modelInfo = handler.getModel();
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022');
expect(modelInfo.info).toBeDefined();
expect(modelInfo.info.maxTokens).toBe(8192);
expect(modelInfo.info.contextWindow).toBe(200_000);
});
it("should handle empty response", async () => {
const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: "text", text: "" }],
})
;(handler["client"].messages as any).create = mockCreate
it('should return default model if invalid model specified', () => {
const invalidHandler = new VertexHandler({
apiModelId: 'invalid-model',
vertexProjectId: 'test-project',
vertexRegion: 'us-central1'
});
const modelInfo = invalidHandler.getModel();
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model
});
});
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
})
describe("getModel", () => {
it("should return correct model info", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(200_000)
})
it("should return default model if invalid model specified", () => {
const invalidHandler = new VertexHandler({
apiModelId: "invalid-model",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})
const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model
})
})
})

View File

@@ -1,289 +1,295 @@
import * as vscode from 'vscode';
import { VsCodeLmHandler } from '../vscode-lm';
import { ApiHandlerOptions } from '../../../shared/api';
import { Anthropic } from '@anthropic-ai/sdk';
import * as vscode from "vscode"
import { VsCodeLmHandler } from "../vscode-lm"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock vscode namespace
jest.mock('vscode', () => {
jest.mock("vscode", () => {
class MockLanguageModelTextPart {
type = 'text';
type = "text"
constructor(public value: string) {}
}
class MockLanguageModelToolCallPart {
type = 'tool_call';
type = "tool_call"
constructor(
public callId: string,
public name: string,
public input: any
public input: any,
) {}
}
return {
workspace: {
onDidChangeConfiguration: jest.fn((callback) => ({
dispose: jest.fn()
}))
dispose: jest.fn(),
})),
},
CancellationTokenSource: jest.fn(() => ({
token: {
isCancellationRequested: false,
onCancellationRequested: jest.fn()
onCancellationRequested: jest.fn(),
},
cancel: jest.fn(),
dispose: jest.fn()
dispose: jest.fn(),
})),
CancellationError: class CancellationError extends Error {
constructor() {
super('Operation cancelled');
this.name = 'CancellationError';
super("Operation cancelled")
this.name = "CancellationError"
}
},
LanguageModelChatMessage: {
Assistant: jest.fn((content) => ({
role: 'assistant',
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
role: "assistant",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})),
User: jest.fn((content) => ({
role: 'user',
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
}))
role: "user",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})),
},
LanguageModelTextPart: MockLanguageModelTextPart,
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
lm: {
selectChatModels: jest.fn()
}
};
});
selectChatModels: jest.fn(),
},
}
})
const mockLanguageModelChat = {
id: 'test-model',
name: 'Test Model',
vendor: 'test-vendor',
family: 'test-family',
version: '1.0',
id: "test-model",
name: "Test Model",
vendor: "test-vendor",
family: "test-family",
version: "1.0",
maxInputTokens: 4096,
sendRequest: jest.fn(),
countTokens: jest.fn()
};
countTokens: jest.fn(),
}
describe('VsCodeLmHandler', () => {
let handler: VsCodeLmHandler;
describe("VsCodeLmHandler", () => {
let handler: VsCodeLmHandler
const defaultOptions: ApiHandlerOptions = {
vsCodeLmModelSelector: {
vendor: 'test-vendor',
family: 'test-family'
}
};
vendor: "test-vendor",
family: "test-family",
},
}
beforeEach(() => {
jest.clearAllMocks();
handler = new VsCodeLmHandler(defaultOptions);
});
jest.clearAllMocks()
handler = new VsCodeLmHandler(defaultOptions)
})
afterEach(() => {
handler.dispose();
});
handler.dispose()
})
describe('constructor', () => {
it('should initialize with provided options', () => {
expect(handler).toBeDefined();
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled();
});
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeDefined()
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled()
})
it('should handle configuration changes', () => {
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0];
callback({ affectsConfiguration: () => true });
it("should handle configuration changes", () => {
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]
callback({ affectsConfiguration: () => true })
// Should reset client when config changes
expect(handler['client']).toBeNull();
});
});
expect(handler["client"]).toBeNull()
})
})
describe('createClient', () => {
it('should create client with selector', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
describe("createClient", () => {
it("should create client with selector", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
const client = await handler['createClient']({
vendor: 'test-vendor',
family: 'test-family'
});
const client = await handler["createClient"]({
vendor: "test-vendor",
family: "test-family",
})
expect(client).toBeDefined();
expect(client.id).toBe('test-model');
expect(client).toBeDefined()
expect(client.id).toBe("test-model")
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
vendor: 'test-vendor',
family: 'test-family'
});
});
vendor: "test-vendor",
family: "test-family",
})
})
it('should return default client when no models available', async () => {
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]);
it("should return default client when no models available", async () => {
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([])
const client = await handler['createClient']({});
expect(client).toBeDefined();
expect(client.id).toBe('default-lm');
expect(client.vendor).toBe('vscode');
});
});
const client = await handler["createClient"]({})
describe('createMessage', () => {
expect(client).toBeDefined()
expect(client.id).toBe("default-lm")
expect(client.vendor).toBe("vscode")
})
})
describe("createMessage", () => {
beforeEach(() => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
mockLanguageModelChat.countTokens.mockResolvedValue(10);
});
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
mockLanguageModelChat.countTokens.mockResolvedValue(10)
})
it('should stream text responses', async () => {
const systemPrompt = 'You are a helpful assistant';
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user' as const,
content: 'Hello'
}];
it("should stream text responses", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Hello",
},
]
const responseText = 'Hello! How can I help you?';
const responseText = "Hello! How can I help you?"
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
yield new vscode.LanguageModelTextPart(responseText);
return;
yield new vscode.LanguageModelTextPart(responseText)
return
})(),
text: (async function* () {
yield responseText;
return;
})()
});
yield responseText
return
})(),
})
const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk);
chunks.push(chunk)
}
expect(chunks).toHaveLength(2); // Text chunk + usage chunk
expect(chunks).toHaveLength(2) // Text chunk + usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: responseText
});
type: "text",
text: responseText,
})
expect(chunks[1]).toMatchObject({
type: 'usage',
type: "usage",
inputTokens: expect.any(Number),
outputTokens: expect.any(Number)
});
});
outputTokens: expect.any(Number),
})
})
it('should handle tool calls', async () => {
const systemPrompt = 'You are a helpful assistant';
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user' as const,
content: 'Calculate 2+2'
}];
it("should handle tool calls", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Calculate 2+2",
},
]
const toolCallData = {
name: 'calculator',
arguments: { operation: 'add', numbers: [2, 2] },
callId: 'call-1'
};
name: "calculator",
arguments: { operation: "add", numbers: [2, 2] },
callId: "call-1",
}
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
yield new vscode.LanguageModelToolCallPart(
toolCallData.callId,
toolCallData.name,
toolCallData.arguments
);
return;
toolCallData.arguments,
)
return
})(),
text: (async function* () {
yield JSON.stringify({ type: 'tool_call', ...toolCallData });
return;
})()
});
yield JSON.stringify({ type: "tool_call", ...toolCallData })
return
})(),
})
const stream = handler.createMessage(systemPrompt, messages);
const chunks = [];
const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk);
chunks.push(chunk)
}
expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk
expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
expect(chunks[0]).toEqual({
type: 'text',
text: JSON.stringify({ type: 'tool_call', ...toolCallData })
});
});
type: "text",
text: JSON.stringify({ type: "tool_call", ...toolCallData }),
})
})
it('should handle errors', async () => {
const systemPrompt = 'You are a helpful assistant';
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user' as const,
content: 'Hello'
}];
it("should handle errors", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Hello",
},
]
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error'));
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
await expect(async () => {
const stream = handler.createMessage(systemPrompt, messages);
const stream = handler.createMessage(systemPrompt, messages)
for await (const _ of stream) {
// consume stream
}
}).rejects.toThrow('API Error');
});
});
}).rejects.toThrow("API Error")
})
})
describe("getModel", () => {
it("should return model info when client exists", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
describe('getModel', () => {
it('should return model info when client exists', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
// Initialize client
await handler['getClient']();
const model = handler.getModel();
expect(model.id).toBe('test-model');
expect(model.info).toBeDefined();
expect(model.info.contextWindow).toBe(4096);
});
await handler["getClient"]()
it('should return fallback model info when no client exists', () => {
const model = handler.getModel();
expect(model.id).toBe('test-vendor/test-family');
expect(model.info).toBeDefined();
});
});
const model = handler.getModel()
expect(model.id).toBe("test-model")
expect(model.info).toBeDefined()
expect(model.info.contextWindow).toBe(4096)
})
describe('completePrompt', () => {
it('should complete single prompt', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
it("should return fallback model info when no client exists", () => {
const model = handler.getModel()
expect(model.id).toBe("test-vendor/test-family")
expect(model.info).toBeDefined()
})
})
const responseText = 'Completed text';
describe("completePrompt", () => {
it("should complete single prompt", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
const responseText = "Completed text"
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
yield new vscode.LanguageModelTextPart(responseText);
return;
yield new vscode.LanguageModelTextPart(responseText)
return
})(),
text: (async function* () {
yield responseText;
return;
})()
});
yield responseText
return
})(),
})
const result = await handler.completePrompt('Test prompt');
expect(result).toBe(responseText);
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled();
});
const result = await handler.completePrompt("Test prompt")
expect(result).toBe(responseText)
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
})
it('should handle errors during completion', async () => {
const mockModel = { ...mockLanguageModelChat };
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
it("should handle errors during completion", async () => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed'));
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))
await expect(handler.completePrompt('Test prompt'))
.rejects
.toThrow('VSCode LM completion error: Completion failed');
});
});
});
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"VSCode LM completion error: Completion failed",
)
})
})
})

View File

@@ -181,14 +181,14 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0,
messages: [{ role: "user", content: prompt }],
stream: false
stream: false,
})
const content = response.content[0]
if (content.type === 'text') {
if (content.type === "text") {
return content.text
}
return ''
return ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Anthropic completion error: ${error.message}`)

View File

@@ -1,4 +1,9 @@
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
import {
BedrockRuntimeClient,
ConverseStreamCommand,
ConverseCommand,
BedrockRuntimeClientConfig,
} from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
@@ -7,275 +12,276 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../
// Define types for stream events based on AWS SDK
export interface StreamEvent {
messageStart?: {
role?: string;
};
messageStop?: {
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence";
additionalModelResponseFields?: Record<string, unknown>;
};
contentBlockStart?: {
start?: {
text?: string;
};
contentBlockIndex?: number;
};
contentBlockDelta?: {
delta?: {
text?: string;
};
contentBlockIndex?: number;
};
metadata?: {
usage?: {
inputTokens: number;
outputTokens: number;
totalTokens?: number; // Made optional since we don't use it
};
metrics?: {
latencyMs: number;
};
};
messageStart?: {
role?: string
}
messageStop?: {
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
additionalModelResponseFields?: Record<string, unknown>
}
contentBlockStart?: {
start?: {
text?: string
}
contentBlockIndex?: number
}
contentBlockDelta?: {
delta?: {
text?: string
}
contentBlockIndex?: number
}
metadata?: {
usage?: {
inputTokens: number
outputTokens: number
totalTokens?: number // Made optional since we don't use it
}
metrics?: {
latencyMs: number
}
}
}
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private client: BedrockRuntimeClient
private options: ApiHandlerOptions
private client: BedrockRuntimeClient
constructor(options: ApiHandlerOptions) {
this.options = options
// Only include credentials if they actually exist
const clientConfig: BedrockRuntimeClientConfig = {
region: this.options.awsRegion || "us-east-1"
}
constructor(options: ApiHandlerOptions) {
this.options = options
if (this.options.awsAccessKey && this.options.awsSecretKey) {
// Create credentials object with all properties at once
clientConfig.credentials = {
accessKeyId: this.options.awsAccessKey,
secretAccessKey: this.options.awsSecretKey,
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {})
}
}
// Only include credentials if they actually exist
const clientConfig: BedrockRuntimeClientConfig = {
region: this.options.awsRegion || "us-east-1",
}
this.client = new BedrockRuntimeClient(clientConfig)
}
if (this.options.awsAccessKey && this.options.awsSecretKey) {
// Create credentials object with all properties at once
clientConfig.credentials = {
accessKeyId: this.options.awsAccessKey,
secretAccessKey: this.options.awsSecretKey,
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
}
}
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const modelConfig = this.getModel()
// Handle cross-region inference
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${modelConfig.id}`
break
case "eu-":
modelId = `eu.${modelConfig.id}`
break
default:
modelId = modelConfig.id
break
}
} else {
modelId = modelConfig.id
}
this.client = new BedrockRuntimeClient(clientConfig)
}
// Convert messages to Bedrock format
const formattedMessages = convertToBedrockConverseMessages(messages)
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const modelConfig = this.getModel()
// Construct the payload
const payload = {
modelId,
messages: formattedMessages,
system: [{ text: systemPrompt }],
inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3,
topP: 0.1,
...(this.options.awsUsePromptCache ? {
promptCache: {
promptCacheId: this.options.awspromptCacheId || ""
}
} : {})
}
}
// Handle cross-region inference
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${modelConfig.id}`
break
case "eu-":
modelId = `eu.${modelConfig.id}`
break
default:
modelId = modelConfig.id
break
}
} else {
modelId = modelConfig.id
}
try {
const command = new ConverseStreamCommand(payload)
const response = await this.client.send(command)
// Convert messages to Bedrock format
const formattedMessages = convertToBedrockConverseMessages(messages)
if (!response.stream) {
throw new Error('No stream available in the response')
}
// Construct the payload
const payload = {
modelId,
messages: formattedMessages,
system: [{ text: systemPrompt }],
inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3,
topP: 0.1,
...(this.options.awsUsePromptCache
? {
promptCache: {
promptCacheId: this.options.awspromptCacheId || "",
},
}
: {}),
},
}
for await (const chunk of response.stream) {
// Parse the chunk as JSON if it's a string (for tests)
let streamEvent: StreamEvent
try {
streamEvent = typeof chunk === 'string' ?
JSON.parse(chunk) :
chunk as unknown as StreamEvent
} catch (e) {
console.error('Failed to parse stream event:', e)
continue
}
try {
const command = new ConverseStreamCommand(payload)
const response = await this.client.send(command)
// Handle metadata events first
if (streamEvent.metadata?.usage) {
yield {
type: "usage",
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
outputTokens: streamEvent.metadata.usage.outputTokens || 0
}
continue
}
if (!response.stream) {
throw new Error("No stream available in the response")
}
// Handle message start
if (streamEvent.messageStart) {
continue
}
for await (const chunk of response.stream) {
// Parse the chunk as JSON if it's a string (for tests)
let streamEvent: StreamEvent
try {
streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
} catch (e) {
console.error("Failed to parse stream event:", e)
continue
}
// Handle content blocks
if (streamEvent.contentBlockStart?.start?.text) {
yield {
type: "text",
text: streamEvent.contentBlockStart.start.text
}
continue
}
// Handle metadata events first
if (streamEvent.metadata?.usage) {
yield {
type: "usage",
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
outputTokens: streamEvent.metadata.usage.outputTokens || 0,
}
continue
}
// Handle content deltas
if (streamEvent.contentBlockDelta?.delta?.text) {
yield {
type: "text",
text: streamEvent.contentBlockDelta.delta.text
}
continue
}
// Handle message start
if (streamEvent.messageStart) {
continue
}
// Handle message stop
if (streamEvent.messageStop) {
continue
}
}
// Handle content blocks
if (streamEvent.contentBlockStart?.start?.text) {
yield {
type: "text",
text: streamEvent.contentBlockStart.start.text,
}
continue
}
} catch (error: unknown) {
console.error('Bedrock Runtime API Error:', error)
// Only access stack if error is an Error object
if (error instanceof Error) {
console.error('Error stack:', error.stack)
yield {
type: "text",
text: `Error: ${error.message}`
}
yield {
type: "usage",
inputTokens: 0,
outputTokens: 0
}
throw error
} else {
const unknownError = new Error("An unknown error occurred")
yield {
type: "text",
text: unknownError.message
}
yield {
type: "usage",
inputTokens: 0,
outputTokens: 0
}
throw unknownError
}
}
}
// Handle content deltas
if (streamEvent.contentBlockDelta?.delta?.text) {
yield {
type: "text",
text: streamEvent.contentBlockDelta.delta.text,
}
continue
}
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId) {
// For tests, allow any model ID
if (process.env.NODE_ENV === 'test') {
return {
id: modelId,
info: {
maxTokens: 5000,
contextWindow: 128_000,
supportsPromptCache: false
}
}
}
// For production, validate against known models
if (modelId in bedrockModels) {
const id = modelId as BedrockModelId
return { id, info: bedrockModels[id] }
}
}
return {
id: bedrockDefaultModelId,
info: bedrockModels[bedrockDefaultModelId]
}
}
// Handle message stop
if (streamEvent.messageStop) {
continue
}
}
} catch (error: unknown) {
console.error("Bedrock Runtime API Error:", error)
// Only access stack if error is an Error object
if (error instanceof Error) {
console.error("Error stack:", error.stack)
yield {
type: "text",
text: `Error: ${error.message}`,
}
yield {
type: "usage",
inputTokens: 0,
outputTokens: 0,
}
throw error
} else {
const unknownError = new Error("An unknown error occurred")
yield {
type: "text",
text: unknownError.message,
}
yield {
type: "usage",
inputTokens: 0,
outputTokens: 0,
}
throw unknownError
}
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const modelConfig = this.getModel()
// Handle cross-region inference
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${modelConfig.id}`
break
case "eu-":
modelId = `eu.${modelConfig.id}`
break
default:
modelId = modelConfig.id
break
}
} else {
modelId = modelConfig.id
}
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
const modelId = this.options.apiModelId
if (modelId) {
// For tests, allow any model ID
if (process.env.NODE_ENV === "test") {
return {
id: modelId,
info: {
maxTokens: 5000,
contextWindow: 128_000,
supportsPromptCache: false,
},
}
}
// For production, validate against known models
if (modelId in bedrockModels) {
const id = modelId as BedrockModelId
return { id, info: bedrockModels[id] }
}
}
return {
id: bedrockDefaultModelId,
info: bedrockModels[bedrockDefaultModelId],
}
}
const payload = {
modelId,
messages: convertToBedrockConverseMessages([{
role: "user",
content: prompt
}]),
inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3,
topP: 0.1
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const modelConfig = this.getModel()
const command = new ConverseCommand(payload)
const response = await this.client.send(command)
// Handle cross-region inference
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${modelConfig.id}`
break
case "eu-":
modelId = `eu.${modelConfig.id}`
break
default:
modelId = modelConfig.id
break
}
} else {
modelId = modelConfig.id
}
if (response.output && response.output instanceof Uint8Array) {
try {
const outputStr = new TextDecoder().decode(response.output)
const output = JSON.parse(outputStr)
if (output.content) {
return output.content
}
} catch (parseError) {
console.error('Failed to parse Bedrock response:', parseError)
}
}
return ''
} catch (error) {
if (error instanceof Error) {
throw new Error(`Bedrock completion error: ${error.message}`)
}
throw error
}
}
const payload = {
modelId,
messages: convertToBedrockConverseMessages([
{
role: "user",
content: prompt,
},
]),
inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3,
topP: 0.1,
},
}
const command = new ConverseCommand(payload)
const response = await this.client.send(command)
if (response.output && response.output instanceof Uint8Array) {
try {
const outputStr = new TextDecoder().decode(response.output)
const output = JSON.parse(outputStr)
if (output.content) {
return output.content
}
} catch (parseError) {
console.error("Failed to parse Bedrock response:", parseError)
}
}
return ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Bedrock completion error: ${error.message}`)
}
throw error
}
}
}

View File

@@ -3,24 +3,24 @@ import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
export class DeepSeekHandler extends OpenAiHandler {
constructor(options: ApiHandlerOptions) {
if (!options.deepSeekApiKey) {
throw new Error("DeepSeek API key is required. Please provide it in the settings.")
}
super({
...options,
openAiApiKey: options.deepSeekApiKey,
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
includeMaxTokens: true
})
}
constructor(options: ApiHandlerOptions) {
if (!options.deepSeekApiKey) {
throw new Error("DeepSeek API key is required. Please provide it in the settings.")
}
super({
...options,
openAiApiKey: options.deepSeekApiKey,
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
includeMaxTokens: true,
})
}
override getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
return {
id: modelId,
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
}
}
override getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
return {
id: modelId,
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
}
}
}

View File

@@ -72,17 +72,17 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
maxTokens = 8_192
}
const { data: completion, response } = await this.client.chat.completions.create({
model: this.getModel().id,
max_tokens: maxTokens,
temperature: 0,
messages: openAiMessages,
stream: true,
}).withResponse();
const { data: completion, response } = await this.client.chat.completions
.create({
model: this.getModel().id,
max_tokens: maxTokens,
temperature: 0,
messages: openAiMessages,
stream: true,
})
.withResponse()
const completionRequestId = response.headers.get(
'x-completion-request-id',
);
const completionRequestId = response.headers.get("x-completion-request-id")
for await (const chunk of completion) {
const delta = chunk.choices[0]?.delta
@@ -96,13 +96,16 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
}
try {
const response = await axios.get(`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, {
headers: {
Authorization: `Bearer ${this.options.glamaApiKey}`,
const response = await axios.get(
`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`,
{
headers: {
Authorization: `Bearer ${this.options.glamaApiKey}`,
},
},
})
)
const completionRequest = response.data;
const completionRequest = response.data
if (completionRequest.tokenUsage) {
yield {
@@ -113,7 +116,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
outputTokens: completionRequest.tokenUsage.completionTokens,
totalCost: parseFloat(completionRequest.totalCostUsd),
}
}
}
} catch (error) {
console.error("Error fetching Glama completion details", error)
}
@@ -126,7 +129,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
if (modelId && modelInfo) {
return { id: modelId, info: modelInfo }
}
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
}
@@ -141,7 +144,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
if (this.getModel().id.startsWith("anthropic/")) {
requestOptions.max_tokens = 8192
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {

View File

@@ -60,7 +60,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
stream: false,
})
return response.choices[0]?.message.content || ""
} catch (error) {

View File

@@ -53,7 +53,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
stream: false,
})
return response.choices[0]?.message.content || ""
} catch (error) {

View File

@@ -32,7 +32,10 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
// o1 doesnt support streaming or non-1 temp but does support a developer prompt
const response = await this.client.chat.completions.create({
model: modelId,
messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
messages: [
{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
...convertToOpenAiMessages(messages),
],
})
yield {
type: "text",
@@ -98,14 +101,14 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
// o1 doesn't support non-1 temp
requestOptions = {
model: modelId,
messages: [{ role: "user", content: prompt }]
messages: [{ role: "user", content: prompt }],
}
break
default:
requestOptions = {
model: modelId,
messages: [{ role: "user", content: prompt }],
temperature: 0
temperature: 0,
}
}

View File

@@ -17,7 +17,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
constructor(options: ApiHandlerOptions) {
this.options = options
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host;
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host
if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) {
this.client = new AzureOpenAI({
baseURL: this.options.openAiBaseUrl,
@@ -39,7 +39,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
if (this.options.openAiStreamingEnabled ?? true) {
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
role: "system",
content: systemPrompt
content: systemPrompt,
}
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
@@ -74,14 +74,14 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
// o1 for instance doesnt support streaming, non-1 temp, or system prompt
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
role: "user",
content: systemPrompt
content: systemPrompt,
}
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [systemMessage, ...convertToOpenAiMessages(messages)],
}
const response = await this.client.chat.completions.create(requestOptions)
yield {
type: "text",
text: response.choices[0]?.message.content || "",
@@ -108,7 +108,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
messages: [{ role: "user", content: prompt }],
temperature: 0,
}
const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
} catch (error) {

View File

@@ -9,12 +9,12 @@ import delay from "delay"
// Add custom interface for OpenRouter params
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
transforms?: string[];
transforms?: string[]
}
// Add custom interface for OpenRouter usage chunk
interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
fullResponseText: string;
fullResponseText: string
}
import { SingleCompletionHandler } from ".."
@@ -35,7 +35,10 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
})
}
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator<ApiStreamChunk> {
async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): AsyncGenerator<ApiStreamChunk> {
// Convert Anthropic messages to OpenAI format
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
@@ -108,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
break
}
// https://openrouter.ai/docs/transforms
let fullResponseText = "";
let fullResponseText = ""
const stream = await this.client.chat.completions.create({
model: this.getModel().id,
max_tokens: maxTokens,
@@ -116,8 +119,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
messages: openAiMessages,
stream: true,
// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] })
} as OpenRouterChatCompletionParams);
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
} as OpenRouterChatCompletionParams)
let genId: string | undefined
@@ -135,11 +138,11 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
const delta = chunk.choices[0]?.delta
if (delta?.content) {
fullResponseText += delta.content;
fullResponseText += delta.content
yield {
type: "text",
text: delta.content,
} as ApiStreamChunk;
} as ApiStreamChunk
}
// if (chunk.usage) {
// yield {
@@ -170,13 +173,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
inputTokens: generation?.native_tokens_prompt || 0,
outputTokens: generation?.native_tokens_completion || 0,
totalCost: generation?.total_cost || 0,
fullResponseText
} as OpenRouterApiStreamUsageChunk;
fullResponseText,
} as OpenRouterApiStreamUsageChunk
} catch (error) {
// ignore if fails
console.error("Error fetching OpenRouter generation details:", error)
}
}
getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.openRouterModelId
@@ -193,7 +195,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: 0,
stream: false
stream: false,
})
if ("error" in response) {

View File

@@ -91,14 +91,14 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0,
messages: [{ role: "user", content: prompt }],
stream: false
stream: false,
})
const content = response.content[0]
if (content.type === 'text') {
if (content.type === "text") {
return content.text
}
return ''
return ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Vertex completion error: ${error.message}`)

View File

@@ -1,31 +1,31 @@
import { Anthropic } from "@anthropic-ai/sdk";
import * as vscode from 'vscode';
import { ApiHandler, SingleCompletionHandler } from "../";
import { calculateApiCost } from "../../utils/cost";
import { ApiStream } from "../transform/stream";
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format";
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils";
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api";
import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode"
import { ApiHandler, SingleCompletionHandler } from "../"
import { calculateApiCost } from "../../utils/cost"
import { ApiStream } from "../transform/stream"
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
/**
* Handles interaction with VS Code's Language Model API for chat-based operations.
* This handler implements the ApiHandler interface to provide VS Code LM specific functionality.
*
*
* @implements {ApiHandler}
*
*
* @remarks
* The handler manages a VS Code language model chat client and provides methods to:
* - Create and manage chat client instances
* - Stream messages using VS Code's Language Model API
* - Retrieve model information
*
*
* @example
* ```typescript
* const options = {
* vsCodeLmModelSelector: { vendor: "copilot", family: "gpt-4" }
* };
* const handler = new VsCodeLmHandler(options);
*
*
* // Stream a conversation
* const systemPrompt = "You are a helpful assistant";
* const messages = [{ role: "user", content: "Hello!" }];
@@ -35,39 +35,36 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
* ```
*/
export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions;
private client: vscode.LanguageModelChat | null;
private disposable: vscode.Disposable | null;
private currentRequestCancellation: vscode.CancellationTokenSource | null;
private options: ApiHandlerOptions
private client: vscode.LanguageModelChat | null
private disposable: vscode.Disposable | null
private currentRequestCancellation: vscode.CancellationTokenSource | null
constructor(options: ApiHandlerOptions) {
this.options = options;
this.client = null;
this.disposable = null;
this.currentRequestCancellation = null;
this.options = options
this.client = null
this.disposable = null
this.currentRequestCancellation = null
try {
// Listen for model changes and reset client
this.disposable = vscode.workspace.onDidChangeConfiguration(event => {
if (event.affectsConfiguration('lm')) {
this.disposable = vscode.workspace.onDidChangeConfiguration((event) => {
if (event.affectsConfiguration("lm")) {
try {
this.client = null;
this.ensureCleanState();
}
catch (error) {
console.error('Error during configuration change cleanup:', error);
this.client = null
this.ensureCleanState()
} catch (error) {
console.error("Error during configuration change cleanup:", error)
}
}
});
}
catch (error) {
})
} catch (error) {
// Ensure cleanup if constructor fails
this.dispose();
this.dispose()
throw new Error(
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : 'Unknown error'}`
);
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : "Unknown error"}`,
)
}
}
@@ -77,46 +74,46 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
* @param selector - Selector criteria to filter language model chat instances
* @returns Promise resolving to the first matching language model chat instance
* @throws Error when no matching models are found with the given selector
*
*
* @example
* const selector = { vendor: "copilot", family: "gpt-4o" };
* const chatClient = await createClient(selector);
*/
async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> {
try {
const models = await vscode.lm.selectChatModels(selector);
const models = await vscode.lm.selectChatModels(selector)
// Use first available model or create a minimal model object
if (models && Array.isArray(models) && models.length > 0) {
return models[0];
return models[0]
}
// Create a minimal model if no models are available
return {
id: 'default-lm',
name: 'Default Language Model',
vendor: 'vscode',
family: 'lm',
version: '1.0',
id: "default-lm",
name: "Default Language Model",
vendor: "vscode",
family: "lm",
version: "1.0",
maxInputTokens: 8192,
sendRequest: async (messages, options, token) => {
// Provide a minimal implementation
return {
stream: (async function* () {
yield new vscode.LanguageModelTextPart(
"Language model functionality is limited. Please check VS Code configuration."
);
"Language model functionality is limited. Please check VS Code configuration.",
)
})(),
text: (async function* () {
yield "Language model functionality is limited. Please check VS Code configuration.";
})()
};
yield "Language model functionality is limited. Please check VS Code configuration."
})(),
}
},
countTokens: async () => 0
};
countTokens: async () => 0,
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`);
const errorMessage = error instanceof Error ? error.message : "Unknown error"
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`)
}
}
@@ -125,242 +122,234 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
*
* @param systemPrompt - The system prompt to initialize the conversation context
* @param messages - An array of message parameters following the Anthropic message format
*
*
* @yields {ApiStream} An async generator that yields either text chunks or tool calls from the model response
*
*
* @throws {Error} When vsCodeLmModelSelector option is not provided
* @throws {Error} When the response stream encounters an error
*
*
* @remarks
* This method handles the initialization of the VS Code LM client if not already created,
* converts the messages to VS Code LM format, and streams the response chunks.
* Tool calls handling is currently a work in progress.
*/
dispose(): void {
if (this.disposable) {
this.disposable.dispose();
this.disposable.dispose()
}
if (this.currentRequestCancellation) {
this.currentRequestCancellation.cancel();
this.currentRequestCancellation.dispose();
this.currentRequestCancellation.cancel()
this.currentRequestCancellation.dispose()
}
}
private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
// Check for required dependencies
if (!this.client) {
console.warn('Cline <Language Model API>: No client available for token counting');
return 0;
console.warn("Cline <Language Model API>: No client available for token counting")
return 0
}
if (!this.currentRequestCancellation) {
console.warn('Cline <Language Model API>: No cancellation token available for token counting');
return 0;
console.warn("Cline <Language Model API>: No cancellation token available for token counting")
return 0
}
// Validate input
if (!text) {
console.debug('Cline <Language Model API>: Empty text provided for token counting');
return 0;
console.debug("Cline <Language Model API>: Empty text provided for token counting")
return 0
}
try {
// Handle different input types
let tokenCount: number;
let tokenCount: number
if (typeof text === 'string') {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
if (typeof text === "string") {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
} else if (text instanceof vscode.LanguageModelChatMessage) {
// For chat messages, ensure we have content
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
console.debug('Cline <Language Model API>: Empty chat message content');
return 0;
console.debug("Cline <Language Model API>: Empty chat message content")
return 0
}
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
} else {
console.warn('Cline <Language Model API>: Invalid input type for token counting');
return 0;
console.warn("Cline <Language Model API>: Invalid input type for token counting")
return 0
}
// Validate the result
if (typeof tokenCount !== 'number') {
console.warn('Cline <Language Model API>: Non-numeric token count received:', tokenCount);
return 0;
if (typeof tokenCount !== "number") {
console.warn("Cline <Language Model API>: Non-numeric token count received:", tokenCount)
return 0
}
if (tokenCount < 0) {
console.warn('Cline <Language Model API>: Negative token count received:', tokenCount);
return 0;
console.warn("Cline <Language Model API>: Negative token count received:", tokenCount)
return 0
}
return tokenCount;
}
catch (error) {
return tokenCount
} catch (error) {
// Handle specific error types
if (error instanceof vscode.CancellationError) {
console.debug('Cline <Language Model API>: Token counting cancelled by user');
return 0;
console.debug("Cline <Language Model API>: Token counting cancelled by user")
return 0
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
console.warn('Cline <Language Model API>: Token counting failed:', errorMessage);
const errorMessage = error instanceof Error ? error.message : "Unknown error"
console.warn("Cline <Language Model API>: Token counting failed:", errorMessage)
// Log additional error details if available
if (error instanceof Error && error.stack) {
console.debug('Token counting error stack:', error.stack);
console.debug("Token counting error stack:", error.stack)
}
return 0; // Fallback to prevent stream interruption
return 0 // Fallback to prevent stream interruption
}
}
private async calculateTotalInputTokens(systemPrompt: string, vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
private async calculateTotalInputTokens(
systemPrompt: string,
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
): Promise<number> {
const systemTokens: number = await this.countTokens(systemPrompt)
const systemTokens: number = await this.countTokens(systemPrompt);
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.countTokens(msg)))
const messageTokens: number[] = await Promise.all(
vsCodeLmMessages.map(msg => this.countTokens(msg))
);
return systemTokens + messageTokens.reduce(
(sum: number, tokens: number): number => sum + tokens, 0
);
return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
}
private ensureCleanState(): void {
if (this.currentRequestCancellation) {
this.currentRequestCancellation.cancel();
this.currentRequestCancellation.dispose();
this.currentRequestCancellation = null;
this.currentRequestCancellation.cancel()
this.currentRequestCancellation.dispose()
this.currentRequestCancellation = null
}
}
private async getClient(): Promise<vscode.LanguageModelChat> {
if (!this.client) {
console.debug('Cline <Language Model API>: Getting client with options:', {
console.debug("Cline <Language Model API>: Getting client with options:", {
vsCodeLmModelSelector: this.options.vsCodeLmModelSelector,
hasOptions: !!this.options,
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : []
});
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [],
})
try {
// Use default empty selector if none provided to get all available models
const selector = this.options?.vsCodeLmModelSelector || {};
console.debug('Cline <Language Model API>: Creating client with selector:', selector);
this.client = await this.createClient(selector);
const selector = this.options?.vsCodeLmModelSelector || {}
console.debug("Cline <Language Model API>: Creating client with selector:", selector)
this.client = await this.createClient(selector)
} catch (error) {
const message = error instanceof Error ? error.message : 'Unknown error';
console.error('Cline <Language Model API>: Client creation failed:', message);
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`);
const message = error instanceof Error ? error.message : "Unknown error"
console.error("Cline <Language Model API>: Client creation failed:", message)
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`)
}
}
return this.client;
return this.client
}
private cleanTerminalOutput(text: string): string {
if (!text) {
return '';
return ""
}
return text
// Нормализуем переносы строк
.replace(/\r\n/g, '\n')
.replace(/\r/g, '\n')
return (
text
// Нормализуем переносы строк
.replace(/\r\n/g, "\n")
.replace(/\r/g, "\n")
// Удаляем ANSI escape sequences
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, '') // Полный набор ANSI sequences
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, '') // CSI sequences
// Удаляем ANSI escape sequences
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // Полный набор ANSI sequences
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, "") // CSI sequences
// Удаляем последовательности установки заголовка терминала и прочие OSC sequences
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, '')
// Удаляем последовательности установки заголовка терминала и прочие OSC sequences
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, "")
// Удаляем управляющие символы
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, '')
// Удаляем управляющие символы
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, "")
// Удаляем escape-последовательности VS Code
.replace(/\x1B[PD].*?\x1B\\/g, '') // DCS sequences
.replace(/\x1B_.*?\x1B\\/g, '') // APC sequences
.replace(/\x1B\^.*?\x1B\\/g, '') // PM sequences
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, '') // Cursor movement and clear screen
// Удаляем escape-последовательности VS Code
.replace(/\x1B[PD].*?\x1B\\/g, "") // DCS sequences
.replace(/\x1B_.*?\x1B\\/g, "") // APC sequences
.replace(/\x1B\^.*?\x1B\\/g, "") // PM sequences
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, "") // Cursor movement and clear screen
// Удаляем пути Windows и служебную информацию
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/mg, '')
.replace(/^;?Cwd=.*$/mg, '')
// Удаляем пути Windows и служебную информацию
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/gm, "")
.replace(/^;?Cwd=.*$/gm, "")
// Очищаем экранированные последовательности
.replace(/\\x[0-9a-fA-F]{2}/g, '')
.replace(/\\u[0-9a-fA-F]{4}/g, '')
// Очищаем экранированные последовательности
.replace(/\\x[0-9a-fA-F]{2}/g, "")
.replace(/\\u[0-9a-fA-F]{4}/g, "")
// Финальная очистка
.replace(/\n{3,}/g, '\n\n') // Убираем множественные пустые строки
.trim();
// Финальная очистка
.replace(/\n{3,}/g, "\n\n") // Убираем множественные пустые строки
.trim()
)
}
private cleanMessageContent(content: any): any {
if (!content) {
return content;
return content
}
if (typeof content === 'string') {
return this.cleanTerminalOutput(content);
if (typeof content === "string") {
return this.cleanTerminalOutput(content)
}
if (Array.isArray(content)) {
return content.map(item => this.cleanMessageContent(item));
return content.map((item) => this.cleanMessageContent(item))
}
if (typeof content === 'object') {
const cleaned: any = {};
if (typeof content === "object") {
const cleaned: any = {}
for (const [key, value] of Object.entries(content)) {
cleaned[key] = this.cleanMessageContent(value);
cleaned[key] = this.cleanMessageContent(value)
}
return cleaned;
return cleaned
}
return content;
return content
}
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
// Ensure clean state before starting a new request
this.ensureCleanState();
const client: vscode.LanguageModelChat = await this.getClient();
this.ensureCleanState()
const client: vscode.LanguageModelChat = await this.getClient()
// Clean system prompt and messages
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt);
const cleanedMessages = messages.map(msg => ({
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt)
const cleanedMessages = messages.map((msg) => ({
...msg,
content: this.cleanMessageContent(msg.content)
}));
content: this.cleanMessageContent(msg.content),
}))
// Convert Anthropic messages to VS Code LM messages
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [
vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt),
...convertToVsCodeLmMessages(cleanedMessages),
];
]
// Initialize cancellation token for the request
this.currentRequestCancellation = new vscode.CancellationTokenSource();
this.currentRequestCancellation = new vscode.CancellationTokenSource()
// Calculate input tokens before starting the stream
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages);
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
let accumulatedText: string = '';
let accumulatedText: string = ""
try {
// Create the response stream with minimal required options
const requestOptions: vscode.LanguageModelChatRequestOptions = {
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`
};
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`,
}
// Note: Tool support is currently provided by the VSCode Language Model API directly
// Extensions can register tools using vscode.lm.registerTool()
@@ -368,40 +357,40 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
const response: vscode.LanguageModelChatResponse = await client.sendRequest(
vsCodeLmMessages,
requestOptions,
this.currentRequestCancellation.token
);
this.currentRequestCancellation.token,
)
// Consume the stream and handle both text and tool call chunks
for await (const chunk of response.stream) {
if (chunk instanceof vscode.LanguageModelTextPart) {
// Validate text part value
if (typeof chunk.value !== 'string') {
console.warn('Cline <Language Model API>: Invalid text part value received:', chunk.value);
continue;
if (typeof chunk.value !== "string") {
console.warn("Cline <Language Model API>: Invalid text part value received:", chunk.value)
continue
}
accumulatedText += chunk.value;
accumulatedText += chunk.value
yield {
type: "text",
text: chunk.value,
};
}
} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
try {
// Validate tool call parameters
if (!chunk.name || typeof chunk.name !== 'string') {
console.warn('Cline <Language Model API>: Invalid tool name received:', chunk.name);
continue;
if (!chunk.name || typeof chunk.name !== "string") {
console.warn("Cline <Language Model API>: Invalid tool name received:", chunk.name)
continue
}
if (!chunk.callId || typeof chunk.callId !== 'string') {
console.warn('Cline <Language Model API>: Invalid tool callId received:', chunk.callId);
continue;
if (!chunk.callId || typeof chunk.callId !== "string") {
console.warn("Cline <Language Model API>: Invalid tool callId received:", chunk.callId)
continue
}
// Ensure input is a valid object
if (!chunk.input || typeof chunk.input !== 'object') {
console.warn('Cline <Language Model API>: Invalid tool input received:', chunk.input);
continue;
if (!chunk.input || typeof chunk.input !== "object") {
console.warn("Cline <Language Model API>: Invalid tool input received:", chunk.input)
continue
}
// Convert tool calls to text format with proper error handling
@@ -409,82 +398,75 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
type: "tool_call",
name: chunk.name,
arguments: chunk.input,
callId: chunk.callId
};
callId: chunk.callId,
}
const toolCallText = JSON.stringify(toolCall);
accumulatedText += toolCallText;
const toolCallText = JSON.stringify(toolCall)
accumulatedText += toolCallText
// Log tool call for debugging
console.debug('Cline <Language Model API>: Processing tool call:', {
console.debug("Cline <Language Model API>: Processing tool call:", {
name: chunk.name,
callId: chunk.callId,
inputSize: JSON.stringify(chunk.input).length
});
inputSize: JSON.stringify(chunk.input).length,
})
yield {
type: "text",
text: toolCallText,
};
}
} catch (error) {
console.error('Cline <Language Model API>: Failed to process tool call:', error);
console.error("Cline <Language Model API>: Failed to process tool call:", error)
// Continue processing other chunks even if one fails
continue;
continue
}
} else {
console.warn('Cline <Language Model API>: Unknown chunk type received:', chunk);
console.warn("Cline <Language Model API>: Unknown chunk type received:", chunk)
}
}
// Count tokens in the accumulated text after stream completion
const totalOutputTokens: number = await this.countTokens(accumulatedText);
const totalOutputTokens: number = await this.countTokens(accumulatedText)
// Report final usage after stream completion
yield {
type: "usage",
inputTokens: totalInputTokens,
outputTokens: totalOutputTokens,
totalCost: calculateApiCost(
this.getModel().info,
totalInputTokens,
totalOutputTokens
)
};
}
catch (error: unknown) {
this.ensureCleanState();
totalCost: calculateApiCost(this.getModel().info, totalInputTokens, totalOutputTokens),
}
} catch (error: unknown) {
this.ensureCleanState()
if (error instanceof vscode.CancellationError) {
throw new Error("Cline <Language Model API>: Request cancelled by user");
throw new Error("Cline <Language Model API>: Request cancelled by user")
}
if (error instanceof Error) {
console.error('Cline <Language Model API>: Stream error details:', {
console.error("Cline <Language Model API>: Stream error details:", {
message: error.message,
stack: error.stack,
name: error.name
});
name: error.name,
})
// Return original error if it's already an Error instance
throw error;
} else if (typeof error === 'object' && error !== null) {
throw error
} else if (typeof error === "object" && error !== null) {
// Handle error-like objects
const errorDetails = JSON.stringify(error, null, 2);
console.error('Cline <Language Model API>: Stream error object:', errorDetails);
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`);
const errorDetails = JSON.stringify(error, null, 2)
console.error("Cline <Language Model API>: Stream error object:", errorDetails)
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`)
} else {
// Fallback for unknown error types
const errorMessage = String(error);
console.error('Cline <Language Model API>: Unknown stream error:', errorMessage);
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`);
const errorMessage = String(error)
console.error("Cline <Language Model API>: Unknown stream error:", errorMessage)
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`)
}
}
}
// Return model information based on the current client state
getModel(): { id: string; info: ModelInfo; } {
getModel(): { id: string; info: ModelInfo } {
if (this.client) {
// Validate client properties
const requiredProps = {
@@ -492,68 +474,69 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
vendor: this.client.vendor,
family: this.client.family,
version: this.client.version,
maxInputTokens: this.client.maxInputTokens
};
maxInputTokens: this.client.maxInputTokens,
}
// Log any missing properties for debugging
for (const [prop, value] of Object.entries(requiredProps)) {
if (!value && value !== 0) {
console.warn(`Cline <Language Model API>: Client missing ${prop} property`);
console.warn(`Cline <Language Model API>: Client missing ${prop} property`)
}
}
// Construct model ID using available information
const modelParts = [
this.client.vendor,
this.client.family,
this.client.version
].filter(Boolean);
const modelParts = [this.client.vendor, this.client.family, this.client.version].filter(Boolean)
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR);
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR)
// Build model info with conservative defaults for missing values
const modelInfo: ModelInfo = {
maxTokens: -1, // Unlimited tokens by default
contextWindow: typeof this.client.maxInputTokens === 'number'
? Math.max(0, this.client.maxInputTokens)
: openAiModelInfoSaneDefaults.contextWindow,
contextWindow:
typeof this.client.maxInputTokens === "number"
? Math.max(0, this.client.maxInputTokens)
: openAiModelInfoSaneDefaults.contextWindow,
supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
supportsPromptCache: true,
inputPrice: 0,
outputPrice: 0,
description: `VSCode Language Model: ${modelId}`
};
description: `VSCode Language Model: ${modelId}`,
}
return { id: modelId, info: modelInfo };
return { id: modelId, info: modelInfo }
}
// Fallback when no client is available
const fallbackId = this.options.vsCodeLmModelSelector
? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector)
: "vscode-lm";
: "vscode-lm"
console.debug('Cline <Language Model API>: No client available, using fallback model info');
console.debug("Cline <Language Model API>: No client available, using fallback model info")
return {
id: fallbackId,
info: {
...openAiModelInfoSaneDefaults,
description: `VSCode Language Model (Fallback): ${fallbackId}`
}
};
description: `VSCode Language Model (Fallback): ${fallbackId}`,
},
}
}
async completePrompt(prompt: string): Promise<string> {
try {
const client = await this.getClient();
const response = await client.sendRequest([vscode.LanguageModelChatMessage.User(prompt)], {}, new vscode.CancellationTokenSource().token);
let result = "";
const client = await this.getClient()
const response = await client.sendRequest(
[vscode.LanguageModelChatMessage.User(prompt)],
{},
new vscode.CancellationTokenSource().token,
)
let result = ""
for await (const chunk of response.stream) {
if (chunk instanceof vscode.LanguageModelTextPart) {
result += chunk.value;
result += chunk.value
}
}
return result;
return result
} catch (error) {
if (error instanceof Error) {
throw new Error(`VSCode LM completion error: ${error.message}`)

View File

@@ -1,252 +1,250 @@
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format'
import { Anthropic } from '@anthropic-ai/sdk'
import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime'
import { StreamEvent } from '../../providers/bedrock'
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../bedrock-converse-format"
import { Anthropic } from "@anthropic-ai/sdk"
import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime"
import { StreamEvent } from "../../providers/bedrock"
describe('bedrock-converse-format', () => {
describe('convertToBedrockConverseMessages', () => {
test('converts simple text messages correctly', () => {
const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'Hello' },
{ role: 'assistant', content: 'Hi there' }
]
describe("bedrock-converse-format", () => {
describe("convertToBedrockConverseMessages", () => {
test("converts simple text messages correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there" },
]
const result = convertToBedrockConverseMessages(messages)
const result = convertToBedrockConverseMessages(messages)
expect(result).toEqual([
{
role: 'user',
content: [{ text: 'Hello' }]
},
{
role: 'assistant',
content: [{ text: 'Hi there' }]
}
])
})
expect(result).toEqual([
{
role: "user",
content: [{ text: "Hello" }],
},
{
role: "assistant",
content: [{ text: "Hi there" }],
},
])
})
test('converts messages with images correctly', () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [
{
type: 'text',
text: 'Look at this image:'
},
{
type: 'image',
source: {
type: 'base64',
data: 'SGVsbG8=', // "Hello" in base64
media_type: 'image/jpeg' as const
}
}
]
}
]
test("converts messages with images correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text",
text: "Look at this image:",
},
{
type: "image",
source: {
type: "base64",
data: "SGVsbG8=", // "Hello" in base64
media_type: "image/jpeg" as const,
},
},
],
},
]
const result = convertToBedrockConverseMessages(messages)
const result = convertToBedrockConverseMessages(messages)
if (!result[0] || !result[0].content) {
fail('Expected result to have content')
return
}
if (!result[0] || !result[0].content) {
fail("Expected result to have content")
return
}
expect(result[0].role).toBe('user')
expect(result[0].content).toHaveLength(2)
expect(result[0].content[0]).toEqual({ text: 'Look at this image:' })
const imageBlock = result[0].content[1] as ContentBlock
if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) {
expect(imageBlock.image.format).toBe('jpeg')
expect(imageBlock.image.source).toBeDefined()
expect(imageBlock.image.source.bytes).toBeDefined()
} else {
fail('Expected image block not found')
}
})
expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(2)
expect(result[0].content[0]).toEqual({ text: "Look at this image:" })
test('converts tool use messages correctly', () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'test-id',
name: 'read_file',
input: {
path: 'test.txt'
}
}
]
}
]
const imageBlock = result[0].content[1] as ContentBlock
if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) {
expect(imageBlock.image.format).toBe("jpeg")
expect(imageBlock.image.source).toBeDefined()
expect(imageBlock.image.source.bytes).toBeDefined()
} else {
fail("Expected image block not found")
}
})
const result = convertToBedrockConverseMessages(messages)
test("converts tool use messages correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "assistant",
content: [
{
type: "tool_use",
id: "test-id",
name: "read_file",
input: {
path: "test.txt",
},
},
],
},
]
if (!result[0] || !result[0].content) {
fail('Expected result to have content')
return
}
const result = convertToBedrockConverseMessages(messages)
expect(result[0].role).toBe('assistant')
const toolBlock = result[0].content[0] as ContentBlock
if ('toolUse' in toolBlock && toolBlock.toolUse) {
expect(toolBlock.toolUse).toEqual({
toolUseId: 'test-id',
name: 'read_file',
input: '<read_file>\n<path>\ntest.txt\n</path>\n</read_file>'
})
} else {
fail('Expected tool use block not found')
}
})
if (!result[0] || !result[0].content) {
fail("Expected result to have content")
return
}
test('converts tool result messages correctly', () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'assistant',
content: [
{
type: 'tool_result',
tool_use_id: 'test-id',
content: [{ type: 'text', text: 'File contents here' }]
}
]
}
]
expect(result[0].role).toBe("assistant")
const toolBlock = result[0].content[0] as ContentBlock
if ("toolUse" in toolBlock && toolBlock.toolUse) {
expect(toolBlock.toolUse).toEqual({
toolUseId: "test-id",
name: "read_file",
input: "<read_file>\n<path>\ntest.txt\n</path>\n</read_file>",
})
} else {
fail("Expected tool use block not found")
}
})
const result = convertToBedrockConverseMessages(messages)
test("converts tool result messages correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "assistant",
content: [
{
type: "tool_result",
tool_use_id: "test-id",
content: [{ type: "text", text: "File contents here" }],
},
],
},
]
if (!result[0] || !result[0].content) {
fail('Expected result to have content')
return
}
const result = convertToBedrockConverseMessages(messages)
expect(result[0].role).toBe('assistant')
const resultBlock = result[0].content[0] as ContentBlock
if ('toolResult' in resultBlock && resultBlock.toolResult) {
const expectedContent: ToolResultContentBlock[] = [
{ text: 'File contents here' }
]
expect(resultBlock.toolResult).toEqual({
toolUseId: 'test-id',
content: expectedContent,
status: 'success'
})
} else {
fail('Expected tool result block not found')
}
})
if (!result[0] || !result[0].content) {
fail("Expected result to have content")
return
}
test('handles text content correctly', () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [
{
type: 'text',
text: 'Hello world'
}
]
}
]
expect(result[0].role).toBe("assistant")
const resultBlock = result[0].content[0] as ContentBlock
if ("toolResult" in resultBlock && resultBlock.toolResult) {
const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }]
expect(resultBlock.toolResult).toEqual({
toolUseId: "test-id",
content: expectedContent,
status: "success",
})
} else {
fail("Expected tool result block not found")
}
})
const result = convertToBedrockConverseMessages(messages)
test("handles text content correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text",
text: "Hello world",
},
],
},
]
if (!result[0] || !result[0].content) {
fail('Expected result to have content')
return
}
const result = convertToBedrockConverseMessages(messages)
expect(result[0].role).toBe('user')
expect(result[0].content).toHaveLength(1)
const textBlock = result[0].content[0] as ContentBlock
expect(textBlock).toEqual({ text: 'Hello world' })
})
})
if (!result[0] || !result[0].content) {
fail("Expected result to have content")
return
}
describe('convertToAnthropicMessage', () => {
test('converts metadata events correctly', () => {
const event: StreamEvent = {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 20
}
}
}
expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(1)
const textBlock = result[0].content[0] as ContentBlock
expect(textBlock).toEqual({ text: "Hello world" })
})
})
const result = convertToAnthropicMessage(event, 'test-model')
describe("convertToAnthropicMessage", () => {
test("converts metadata events correctly", () => {
const event: StreamEvent = {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 20,
},
},
}
expect(result).toEqual({
id: '',
type: 'message',
role: 'assistant',
model: 'test-model',
usage: {
input_tokens: 10,
output_tokens: 20
}
})
})
const result = convertToAnthropicMessage(event, "test-model")
test('converts content block start events correctly', () => {
const event: StreamEvent = {
contentBlockStart: {
start: {
text: 'Hello'
}
}
}
expect(result).toEqual({
id: "",
type: "message",
role: "assistant",
model: "test-model",
usage: {
input_tokens: 10,
output_tokens: 20,
},
})
})
const result = convertToAnthropicMessage(event, 'test-model')
test("converts content block start events correctly", () => {
const event: StreamEvent = {
contentBlockStart: {
start: {
text: "Hello",
},
},
}
expect(result).toEqual({
type: 'message',
role: 'assistant',
content: [{ type: 'text', text: 'Hello' }],
model: 'test-model'
})
})
const result = convertToAnthropicMessage(event, "test-model")
test('converts content block delta events correctly', () => {
const event: StreamEvent = {
contentBlockDelta: {
delta: {
text: ' world'
}
}
}
expect(result).toEqual({
type: "message",
role: "assistant",
content: [{ type: "text", text: "Hello" }],
model: "test-model",
})
})
const result = convertToAnthropicMessage(event, 'test-model')
test("converts content block delta events correctly", () => {
const event: StreamEvent = {
contentBlockDelta: {
delta: {
text: " world",
},
},
}
expect(result).toEqual({
type: 'message',
role: 'assistant',
content: [{ type: 'text', text: ' world' }],
model: 'test-model'
})
})
const result = convertToAnthropicMessage(event, "test-model")
test('converts message stop events correctly', () => {
const event: StreamEvent = {
messageStop: {
stopReason: 'end_turn' as const
}
}
expect(result).toEqual({
type: "message",
role: "assistant",
content: [{ type: "text", text: " world" }],
model: "test-model",
})
})
const result = convertToAnthropicMessage(event, 'test-model')
test("converts message stop events correctly", () => {
const event: StreamEvent = {
messageStop: {
stopReason: "end_turn" as const,
},
}
expect(result).toEqual({
type: 'message',
role: 'assistant',
stop_reason: 'end_turn',
stop_sequence: null,
model: 'test-model'
})
})
})
const result = convertToAnthropicMessage(event, "test-model")
expect(result).toEqual({
type: "message",
role: "assistant",
stop_reason: "end_turn",
stop_sequence: null,
model: "test-model",
})
})
})
})

View File

@@ -1,257 +1,275 @@
import { convertToOpenAiMessages, convertToAnthropicMessage } from '../openai-format';
import { Anthropic } from '@anthropic-ai/sdk';
import OpenAI from 'openai';
import { convertToOpenAiMessages, convertToAnthropicMessage } from "../openai-format"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, 'choices'> & {
choices: Array<Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
message: OpenAI.Chat.Completions.ChatCompletion.Choice['message'];
finish_reason: string;
index: number;
}>;
};
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, "choices"> & {
choices: Array<
Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
message: OpenAI.Chat.Completions.ChatCompletion.Choice["message"]
finish_reason: string
index: number
}
>
}
describe('OpenAI Format Transformations', () => {
describe('convertToOpenAiMessages', () => {
it('should convert simple text messages', () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
}
];
describe("OpenAI Format Transformations", () => {
describe("convertToOpenAiMessages", () => {
it("should convert simple text messages", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
{
role: "assistant",
content: "Hi there!",
},
]
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
expect(openAiMessages).toHaveLength(2);
expect(openAiMessages[0]).toEqual({
role: 'user',
content: 'Hello'
});
expect(openAiMessages[1]).toEqual({
role: 'assistant',
content: 'Hi there!'
});
});
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(2)
expect(openAiMessages[0]).toEqual({
role: "user",
content: "Hello",
})
expect(openAiMessages[1]).toEqual({
role: "assistant",
content: "Hi there!",
})
})
it('should handle messages with image content', () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [
{
type: 'text',
text: 'What is in this image?'
},
{
type: 'image',
source: {
type: 'base64',
media_type: 'image/jpeg',
data: 'base64data'
}
}
]
}
];
it("should handle messages with image content", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text",
text: "What is in this image?",
},
{
type: "image",
source: {
type: "base64",
media_type: "image/jpeg",
data: "base64data",
},
},
],
},
]
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
expect(openAiMessages).toHaveLength(1);
expect(openAiMessages[0].role).toBe('user');
const content = openAiMessages[0].content as Array<{
type: string;
text?: string;
image_url?: { url: string };
}>;
expect(Array.isArray(content)).toBe(true);
expect(content).toHaveLength(2);
expect(content[0]).toEqual({ type: 'text', text: 'What is in this image?' });
expect(content[1]).toEqual({
type: 'image_url',
image_url: { url: 'data:image/jpeg;base64,base64data' }
});
});
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(1)
expect(openAiMessages[0].role).toBe("user")
it('should handle assistant messages with tool use', () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'assistant',
content: [
{
type: 'text',
text: 'Let me check the weather.'
},
{
type: 'tool_use',
id: 'weather-123',
name: 'get_weather',
input: { city: 'London' }
}
]
}
];
const content = openAiMessages[0].content as Array<{
type: string
text?: string
image_url?: { url: string }
}>
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
expect(openAiMessages).toHaveLength(1);
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam;
expect(assistantMessage.role).toBe('assistant');
expect(assistantMessage.content).toBe('Let me check the weather.');
expect(assistantMessage.tool_calls).toHaveLength(1);
expect(assistantMessage.tool_calls![0]).toEqual({
id: 'weather-123',
type: 'function',
function: {
name: 'get_weather',
arguments: JSON.stringify({ city: 'London' })
}
});
});
expect(Array.isArray(content)).toBe(true)
expect(content).toHaveLength(2)
expect(content[0]).toEqual({ type: "text", text: "What is in this image?" })
expect(content[1]).toEqual({
type: "image_url",
image_url: { url: "data:image/jpeg;base64,base64data" },
})
})
it('should handle user messages with tool results', () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'weather-123',
content: 'Current temperature in London: 20°C'
}
]
}
];
it("should handle assistant messages with tool use", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: "assistant",
content: [
{
type: "text",
text: "Let me check the weather.",
},
{
type: "tool_use",
id: "weather-123",
name: "get_weather",
input: { city: "London" },
},
],
},
]
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
expect(openAiMessages).toHaveLength(1);
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam;
expect(toolMessage.role).toBe('tool');
expect(toolMessage.tool_call_id).toBe('weather-123');
expect(toolMessage.content).toBe('Current temperature in London: 20°C');
});
});
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(1)
describe('convertToAnthropicMessage', () => {
it('should convert simple completion', () => {
const openAiCompletion: PartialChatCompletion = {
id: 'completion-123',
model: 'gpt-4',
choices: [{
message: {
role: 'assistant',
content: 'Hello there!',
refusal: null
},
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
},
created: 123456789,
object: 'chat.completion'
};
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam
expect(assistantMessage.role).toBe("assistant")
expect(assistantMessage.content).toBe("Let me check the weather.")
expect(assistantMessage.tool_calls).toHaveLength(1)
expect(assistantMessage.tool_calls![0]).toEqual({
id: "weather-123",
type: "function",
function: {
name: "get_weather",
arguments: JSON.stringify({ city: "London" }),
},
})
})
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
expect(anthropicMessage.id).toBe('completion-123');
expect(anthropicMessage.role).toBe('assistant');
expect(anthropicMessage.content).toHaveLength(1);
expect(anthropicMessage.content[0]).toEqual({
type: 'text',
text: 'Hello there!'
});
expect(anthropicMessage.stop_reason).toBe('end_turn');
expect(anthropicMessage.usage).toEqual({
input_tokens: 10,
output_tokens: 5
});
});
it("should handle user messages with tool results", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "tool_result",
tool_use_id: "weather-123",
content: "Current temperature in London: 20°C",
},
],
},
]
it('should handle tool calls in completion', () => {
const openAiCompletion: PartialChatCompletion = {
id: 'completion-123',
model: 'gpt-4',
choices: [{
message: {
role: 'assistant',
content: 'Let me check the weather.',
tool_calls: [{
id: 'weather-123',
type: 'function',
function: {
name: 'get_weather',
arguments: '{"city":"London"}'
}
}],
refusal: null
},
finish_reason: 'tool_calls',
index: 0
}],
usage: {
prompt_tokens: 15,
completion_tokens: 8,
total_tokens: 23
},
created: 123456789,
object: 'chat.completion'
};
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(1)
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
expect(anthropicMessage.content).toHaveLength(2);
expect(anthropicMessage.content[0]).toEqual({
type: 'text',
text: 'Let me check the weather.'
});
expect(anthropicMessage.content[1]).toEqual({
type: 'tool_use',
id: 'weather-123',
name: 'get_weather',
input: { city: 'London' }
});
expect(anthropicMessage.stop_reason).toBe('tool_use');
});
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam
expect(toolMessage.role).toBe("tool")
expect(toolMessage.tool_call_id).toBe("weather-123")
expect(toolMessage.content).toBe("Current temperature in London: 20°C")
})
})
it('should handle invalid tool call arguments', () => {
const openAiCompletion: PartialChatCompletion = {
id: 'completion-123',
model: 'gpt-4',
choices: [{
message: {
role: 'assistant',
content: 'Testing invalid arguments',
tool_calls: [{
id: 'test-123',
type: 'function',
function: {
name: 'test_function',
arguments: 'invalid json'
}
}],
refusal: null
},
finish_reason: 'tool_calls',
index: 0
}],
created: 123456789,
object: 'chat.completion'
};
describe("convertToAnthropicMessage", () => {
it("should convert simple completion", () => {
const openAiCompletion: PartialChatCompletion = {
id: "completion-123",
model: "gpt-4",
choices: [
{
message: {
role: "assistant",
content: "Hello there!",
refusal: null,
},
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
created: 123456789,
object: "chat.completion",
}
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
expect(anthropicMessage.content).toHaveLength(2);
expect(anthropicMessage.content[1]).toEqual({
type: 'tool_use',
id: 'test-123',
name: 'test_function',
input: {} // Should default to empty object for invalid JSON
});
});
});
});
const anthropicMessage = convertToAnthropicMessage(
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
)
expect(anthropicMessage.id).toBe("completion-123")
expect(anthropicMessage.role).toBe("assistant")
expect(anthropicMessage.content).toHaveLength(1)
expect(anthropicMessage.content[0]).toEqual({
type: "text",
text: "Hello there!",
})
expect(anthropicMessage.stop_reason).toBe("end_turn")
expect(anthropicMessage.usage).toEqual({
input_tokens: 10,
output_tokens: 5,
})
})
it("should handle tool calls in completion", () => {
const openAiCompletion: PartialChatCompletion = {
id: "completion-123",
model: "gpt-4",
choices: [
{
message: {
role: "assistant",
content: "Let me check the weather.",
tool_calls: [
{
id: "weather-123",
type: "function",
function: {
name: "get_weather",
arguments: '{"city":"London"}',
},
},
],
refusal: null,
},
finish_reason: "tool_calls",
index: 0,
},
],
usage: {
prompt_tokens: 15,
completion_tokens: 8,
total_tokens: 23,
},
created: 123456789,
object: "chat.completion",
}
const anthropicMessage = convertToAnthropicMessage(
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
)
expect(anthropicMessage.content).toHaveLength(2)
expect(anthropicMessage.content[0]).toEqual({
type: "text",
text: "Let me check the weather.",
})
expect(anthropicMessage.content[1]).toEqual({
type: "tool_use",
id: "weather-123",
name: "get_weather",
input: { city: "London" },
})
expect(anthropicMessage.stop_reason).toBe("tool_use")
})
it("should handle invalid tool call arguments", () => {
const openAiCompletion: PartialChatCompletion = {
id: "completion-123",
model: "gpt-4",
choices: [
{
message: {
role: "assistant",
content: "Testing invalid arguments",
tool_calls: [
{
id: "test-123",
type: "function",
function: {
name: "test_function",
arguments: "invalid json",
},
},
],
refusal: null,
},
finish_reason: "tool_calls",
index: 0,
},
],
created: 123456789,
object: "chat.completion",
}
const anthropicMessage = convertToAnthropicMessage(
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
)
expect(anthropicMessage.content).toHaveLength(2)
expect(anthropicMessage.content[1]).toEqual({
type: "tool_use",
id: "test-123",
name: "test_function",
input: {}, // Should default to empty object for invalid JSON
})
})
})
})

View File

@@ -1,114 +1,114 @@
import { ApiStreamChunk } from '../stream';
import { ApiStreamChunk } from "../stream"
describe('API Stream Types', () => {
describe('ApiStreamChunk', () => {
it('should correctly handle text chunks', () => {
const textChunk: ApiStreamChunk = {
type: 'text',
text: 'Hello world'
};
describe("API Stream Types", () => {
describe("ApiStreamChunk", () => {
it("should correctly handle text chunks", () => {
const textChunk: ApiStreamChunk = {
type: "text",
text: "Hello world",
}
expect(textChunk.type).toBe('text');
expect(textChunk.text).toBe('Hello world');
});
expect(textChunk.type).toBe("text")
expect(textChunk.text).toBe("Hello world")
})
it('should correctly handle usage chunks with cache information', () => {
const usageChunk: ApiStreamChunk = {
type: 'usage',
inputTokens: 100,
outputTokens: 50,
cacheWriteTokens: 20,
cacheReadTokens: 10
};
it("should correctly handle usage chunks with cache information", () => {
const usageChunk: ApiStreamChunk = {
type: "usage",
inputTokens: 100,
outputTokens: 50,
cacheWriteTokens: 20,
cacheReadTokens: 10,
}
expect(usageChunk.type).toBe('usage');
expect(usageChunk.inputTokens).toBe(100);
expect(usageChunk.outputTokens).toBe(50);
expect(usageChunk.cacheWriteTokens).toBe(20);
expect(usageChunk.cacheReadTokens).toBe(10);
});
expect(usageChunk.type).toBe("usage")
expect(usageChunk.inputTokens).toBe(100)
expect(usageChunk.outputTokens).toBe(50)
expect(usageChunk.cacheWriteTokens).toBe(20)
expect(usageChunk.cacheReadTokens).toBe(10)
})
it('should handle usage chunks without cache tokens', () => {
const usageChunk: ApiStreamChunk = {
type: 'usage',
inputTokens: 100,
outputTokens: 50
};
it("should handle usage chunks without cache tokens", () => {
const usageChunk: ApiStreamChunk = {
type: "usage",
inputTokens: 100,
outputTokens: 50,
}
expect(usageChunk.type).toBe('usage');
expect(usageChunk.inputTokens).toBe(100);
expect(usageChunk.outputTokens).toBe(50);
expect(usageChunk.cacheWriteTokens).toBeUndefined();
expect(usageChunk.cacheReadTokens).toBeUndefined();
});
expect(usageChunk.type).toBe("usage")
expect(usageChunk.inputTokens).toBe(100)
expect(usageChunk.outputTokens).toBe(50)
expect(usageChunk.cacheWriteTokens).toBeUndefined()
expect(usageChunk.cacheReadTokens).toBeUndefined()
})
it('should handle text chunks with empty strings', () => {
const emptyTextChunk: ApiStreamChunk = {
type: 'text',
text: ''
};
it("should handle text chunks with empty strings", () => {
const emptyTextChunk: ApiStreamChunk = {
type: "text",
text: "",
}
expect(emptyTextChunk.type).toBe('text');
expect(emptyTextChunk.text).toBe('');
});
expect(emptyTextChunk.type).toBe("text")
expect(emptyTextChunk.text).toBe("")
})
it('should handle usage chunks with zero tokens', () => {
const zeroUsageChunk: ApiStreamChunk = {
type: 'usage',
inputTokens: 0,
outputTokens: 0
};
it("should handle usage chunks with zero tokens", () => {
const zeroUsageChunk: ApiStreamChunk = {
type: "usage",
inputTokens: 0,
outputTokens: 0,
}
expect(zeroUsageChunk.type).toBe('usage');
expect(zeroUsageChunk.inputTokens).toBe(0);
expect(zeroUsageChunk.outputTokens).toBe(0);
});
expect(zeroUsageChunk.type).toBe("usage")
expect(zeroUsageChunk.inputTokens).toBe(0)
expect(zeroUsageChunk.outputTokens).toBe(0)
})
it('should handle usage chunks with large token counts', () => {
const largeUsageChunk: ApiStreamChunk = {
type: 'usage',
inputTokens: 1000000,
outputTokens: 500000,
cacheWriteTokens: 200000,
cacheReadTokens: 100000
};
it("should handle usage chunks with large token counts", () => {
const largeUsageChunk: ApiStreamChunk = {
type: "usage",
inputTokens: 1000000,
outputTokens: 500000,
cacheWriteTokens: 200000,
cacheReadTokens: 100000,
}
expect(largeUsageChunk.type).toBe('usage');
expect(largeUsageChunk.inputTokens).toBe(1000000);
expect(largeUsageChunk.outputTokens).toBe(500000);
expect(largeUsageChunk.cacheWriteTokens).toBe(200000);
expect(largeUsageChunk.cacheReadTokens).toBe(100000);
});
expect(largeUsageChunk.type).toBe("usage")
expect(largeUsageChunk.inputTokens).toBe(1000000)
expect(largeUsageChunk.outputTokens).toBe(500000)
expect(largeUsageChunk.cacheWriteTokens).toBe(200000)
expect(largeUsageChunk.cacheReadTokens).toBe(100000)
})
it('should handle text chunks with special characters', () => {
const specialCharsChunk: ApiStreamChunk = {
type: 'text',
text: '!@#$%^&*()_+-=[]{}|;:,.<>?`~'
};
it("should handle text chunks with special characters", () => {
const specialCharsChunk: ApiStreamChunk = {
type: "text",
text: "!@#$%^&*()_+-=[]{}|;:,.<>?`~",
}
expect(specialCharsChunk.type).toBe('text');
expect(specialCharsChunk.text).toBe('!@#$%^&*()_+-=[]{}|;:,.<>?`~');
});
expect(specialCharsChunk.type).toBe("text")
expect(specialCharsChunk.text).toBe("!@#$%^&*()_+-=[]{}|;:,.<>?`~")
})
it('should handle text chunks with unicode characters', () => {
const unicodeChunk: ApiStreamChunk = {
type: 'text',
text: '你好世界👋🌍'
};
it("should handle text chunks with unicode characters", () => {
const unicodeChunk: ApiStreamChunk = {
type: "text",
text: "你好世界👋🌍",
}
expect(unicodeChunk.type).toBe('text');
expect(unicodeChunk.text).toBe('你好世界👋🌍');
});
expect(unicodeChunk.type).toBe("text")
expect(unicodeChunk.text).toBe("你好世界👋🌍")
})
it('should handle text chunks with multiline content', () => {
const multilineChunk: ApiStreamChunk = {
type: 'text',
text: 'Line 1\nLine 2\nLine 3'
};
it("should handle text chunks with multiline content", () => {
const multilineChunk: ApiStreamChunk = {
type: "text",
text: "Line 1\nLine 2\nLine 3",
}
expect(multilineChunk.type).toBe('text');
expect(multilineChunk.text).toBe('Line 1\nLine 2\nLine 3');
expect(multilineChunk.text.split('\n')).toHaveLength(3);
});
});
});
expect(multilineChunk.type).toBe("text")
expect(multilineChunk.text).toBe("Line 1\nLine 2\nLine 3")
expect(multilineChunk.text.split("\n")).toHaveLength(3)
})
})
})

View File

@@ -1,66 +1,66 @@
import { Anthropic } from "@anthropic-ai/sdk";
import * as vscode from 'vscode';
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from '../vscode-lm-format';
import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode"
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from "../vscode-lm-format"
// Mock crypto
const mockCrypto = {
randomUUID: () => 'test-uuid'
};
global.crypto = mockCrypto as any;
randomUUID: () => "test-uuid",
}
global.crypto = mockCrypto as any
// Define types for our mocked classes
interface MockLanguageModelTextPart {
type: 'text';
value: string;
type: "text"
value: string
}
interface MockLanguageModelToolCallPart {
type: 'tool_call';
callId: string;
name: string;
input: any;
type: "tool_call"
callId: string
name: string
input: any
}
interface MockLanguageModelToolResultPart {
type: 'tool_result';
toolUseId: string;
parts: MockLanguageModelTextPart[];
type: "tool_result"
toolUseId: string
parts: MockLanguageModelTextPart[]
}
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart;
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart
interface MockLanguageModelChatMessage {
role: string;
name?: string;
content: MockMessageContent[];
role: string
name?: string
content: MockMessageContent[]
}
// Mock vscode namespace
jest.mock('vscode', () => {
jest.mock("vscode", () => {
const LanguageModelChatMessageRole = {
Assistant: 'assistant',
User: 'user'
};
Assistant: "assistant",
User: "user",
}
class MockLanguageModelTextPart {
type = 'text';
type = "text"
constructor(public value: string) {}
}
class MockLanguageModelToolCallPart {
type = 'tool_call';
type = "tool_call"
constructor(
public callId: string,
public name: string,
public input: any
public input: any,
) {}
}
class MockLanguageModelToolResultPart {
type = 'tool_result';
type = "tool_result"
constructor(
public toolUseId: string,
public parts: MockLanguageModelTextPart[]
public parts: MockLanguageModelTextPart[],
) {}
}
@@ -68,179 +68,189 @@ jest.mock('vscode', () => {
LanguageModelChatMessage: {
Assistant: jest.fn((content) => ({
role: LanguageModelChatMessageRole.Assistant,
name: 'assistant',
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
name: "assistant",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})),
User: jest.fn((content) => ({
role: LanguageModelChatMessageRole.User,
name: 'user',
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
}))
name: "user",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})),
},
LanguageModelChatMessageRole,
LanguageModelTextPart: MockLanguageModelTextPart,
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
LanguageModelToolResultPart: MockLanguageModelToolResultPart
};
});
LanguageModelToolResultPart: MockLanguageModelToolResultPart,
}
})
describe('vscode-lm-format', () => {
describe('convertToVsCodeLmMessages', () => {
it('should convert simple string messages', () => {
describe("vscode-lm-format", () => {
describe("convertToVsCodeLmMessages", () => {
it("should convert simple string messages", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'Hello' },
{ role: 'assistant', content: 'Hi there' }
];
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there" },
]
const result = convertToVsCodeLmMessages(messages);
expect(result).toHaveLength(2);
expect(result[0].role).toBe('user');
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe('Hello');
expect(result[1].role).toBe('assistant');
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe('Hi there');
});
const result = convertToVsCodeLmMessages(messages)
it('should handle complex user messages with tool results', () => {
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user',
content: [
{ type: 'text', text: 'Here is the result:' },
{
type: 'tool_result',
tool_use_id: 'tool-1',
content: 'Tool output'
}
]
}];
expect(result).toHaveLength(2)
expect(result[0].role).toBe("user")
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe("Hello")
expect(result[1].role).toBe("assistant")
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe("Hi there")
})
const result = convertToVsCodeLmMessages(messages);
expect(result).toHaveLength(1);
expect(result[0].role).toBe('user');
expect(result[0].content).toHaveLength(2);
const [toolResult, textContent] = result[0].content as [MockLanguageModelToolResultPart, MockLanguageModelTextPart];
expect(toolResult.type).toBe('tool_result');
expect(textContent.type).toBe('text');
});
it("should handle complex user messages with tool results", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{ type: "text", text: "Here is the result:" },
{
type: "tool_result",
tool_use_id: "tool-1",
content: "Tool output",
},
],
},
]
it('should handle complex assistant messages with tool calls', () => {
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'assistant',
content: [
{ type: 'text', text: 'Let me help you with that.' },
{
type: 'tool_use',
id: 'tool-1',
name: 'calculator',
input: { operation: 'add', numbers: [2, 2] }
}
]
}];
const result = convertToVsCodeLmMessages(messages)
const result = convertToVsCodeLmMessages(messages);
expect(result).toHaveLength(1);
expect(result[0].role).toBe('assistant');
expect(result[0].content).toHaveLength(2);
const [toolCall, textContent] = result[0].content as [MockLanguageModelToolCallPart, MockLanguageModelTextPart];
expect(toolCall.type).toBe('tool_call');
expect(textContent.type).toBe('text');
});
expect(result).toHaveLength(1)
expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(2)
const [toolResult, textContent] = result[0].content as [
MockLanguageModelToolResultPart,
MockLanguageModelTextPart,
]
expect(toolResult.type).toBe("tool_result")
expect(textContent.type).toBe("text")
})
it('should handle image blocks with appropriate placeholders', () => {
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user',
content: [
{ type: 'text', text: 'Look at this:' },
{
type: 'image',
source: {
type: 'base64',
media_type: 'image/png',
data: 'base64data'
}
}
]
}];
it("should handle complex assistant messages with tool calls", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "assistant",
content: [
{ type: "text", text: "Let me help you with that." },
{
type: "tool_use",
id: "tool-1",
name: "calculator",
input: { operation: "add", numbers: [2, 2] },
},
],
},
]
const result = convertToVsCodeLmMessages(messages);
expect(result).toHaveLength(1);
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart;
expect(imagePlaceholder.value).toContain('[Image (base64): image/png not supported by VSCode LM API]');
});
});
const result = convertToVsCodeLmMessages(messages)
describe('convertToAnthropicRole', () => {
it('should convert assistant role correctly', () => {
const result = convertToAnthropicRole('assistant' as any);
expect(result).toBe('assistant');
});
expect(result).toHaveLength(1)
expect(result[0].role).toBe("assistant")
expect(result[0].content).toHaveLength(2)
const [toolCall, textContent] = result[0].content as [
MockLanguageModelToolCallPart,
MockLanguageModelTextPart,
]
expect(toolCall.type).toBe("tool_call")
expect(textContent.type).toBe("text")
})
it('should convert user role correctly', () => {
const result = convertToAnthropicRole('user' as any);
expect(result).toBe('user');
});
it("should handle image blocks with appropriate placeholders", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{ type: "text", text: "Look at this:" },
{
type: "image",
source: {
type: "base64",
media_type: "image/png",
data: "base64data",
},
},
],
},
]
it('should return null for unknown roles', () => {
const result = convertToAnthropicRole('unknown' as any);
expect(result).toBeNull();
});
});
const result = convertToVsCodeLmMessages(messages)
describe('convertToAnthropicMessage', () => {
it('should convert assistant message with text content', async () => {
expect(result).toHaveLength(1)
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart
expect(imagePlaceholder.value).toContain("[Image (base64): image/png not supported by VSCode LM API]")
})
})
describe("convertToAnthropicRole", () => {
it("should convert assistant role correctly", () => {
const result = convertToAnthropicRole("assistant" as any)
expect(result).toBe("assistant")
})
it("should convert user role correctly", () => {
const result = convertToAnthropicRole("user" as any)
expect(result).toBe("user")
})
it("should return null for unknown roles", () => {
const result = convertToAnthropicRole("unknown" as any)
expect(result).toBeNull()
})
})
describe("convertToAnthropicMessage", () => {
it("should convert assistant message with text content", async () => {
const vsCodeMessage = {
role: 'assistant',
name: 'assistant',
content: [new vscode.LanguageModelTextPart('Hello')]
};
role: "assistant",
name: "assistant",
content: [new vscode.LanguageModelTextPart("Hello")],
}
const result = await convertToAnthropicMessage(vsCodeMessage as any);
expect(result.role).toBe('assistant');
expect(result.content).toHaveLength(1);
const result = await convertToAnthropicMessage(vsCodeMessage as any)
expect(result.role).toBe("assistant")
expect(result.content).toHaveLength(1)
expect(result.content[0]).toEqual({
type: 'text',
text: 'Hello'
});
expect(result.id).toBe('test-uuid');
});
type: "text",
text: "Hello",
})
expect(result.id).toBe("test-uuid")
})
it('should convert assistant message with tool calls', async () => {
it("should convert assistant message with tool calls", async () => {
const vsCodeMessage = {
role: 'assistant',
name: 'assistant',
content: [new vscode.LanguageModelToolCallPart(
'call-1',
'calculator',
{ operation: 'add', numbers: [2, 2] }
)]
};
role: "assistant",
name: "assistant",
content: [
new vscode.LanguageModelToolCallPart("call-1", "calculator", { operation: "add", numbers: [2, 2] }),
],
}
const result = await convertToAnthropicMessage(vsCodeMessage as any);
expect(result.content).toHaveLength(1);
const result = await convertToAnthropicMessage(vsCodeMessage as any)
expect(result.content).toHaveLength(1)
expect(result.content[0]).toEqual({
type: 'tool_use',
id: 'call-1',
name: 'calculator',
input: { operation: 'add', numbers: [2, 2] }
});
expect(result.id).toBe('test-uuid');
});
type: "tool_use",
id: "call-1",
name: "calculator",
input: { operation: "add", numbers: [2, 2] },
})
expect(result.id).toBe("test-uuid")
})
it('should throw error for non-assistant messages', async () => {
it("should throw error for non-assistant messages", async () => {
const vsCodeMessage = {
role: 'user',
name: 'user',
content: [new vscode.LanguageModelTextPart('Hello')]
};
role: "user",
name: "user",
content: [new vscode.LanguageModelTextPart("Hello")],
}
await expect(convertToAnthropicMessage(vsCodeMessage as any))
.rejects
.toThrow('Cline <Language Model API>: Only assistant messages are supported.');
});
});
});
await expect(convertToAnthropicMessage(vsCodeMessage as any)).rejects.toThrow(
"Cline <Language Model API>: Only assistant messages are supported.",
)
})
})
})

View File

@@ -8,210 +8,216 @@ import { StreamEvent } from "../providers/bedrock"
/**
* Convert Anthropic messages to Bedrock Converse format
*/
export function convertToBedrockConverseMessages(
anthropicMessages: Anthropic.Messages.MessageParam[]
): Message[] {
return anthropicMessages.map(anthropicMessage => {
// Map Anthropic roles to Bedrock roles
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
return anthropicMessages.map((anthropicMessage) => {
// Map Anthropic roles to Bedrock roles
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
if (typeof anthropicMessage.content === "string") {
return {
role,
content: [{
text: anthropicMessage.content
}] as ContentBlock[]
}
}
if (typeof anthropicMessage.content === "string") {
return {
role,
content: [
{
text: anthropicMessage.content,
},
] as ContentBlock[],
}
}
// Process complex content types
const content = anthropicMessage.content.map(block => {
const messageBlock = block as MessageContent & {
id?: string,
tool_use_id?: string,
content?: Array<{ type: string, text: string }>,
output?: string | Array<{ type: string, text: string }>
}
// Process complex content types
const content = anthropicMessage.content.map((block) => {
const messageBlock = block as MessageContent & {
id?: string
tool_use_id?: string
content?: Array<{ type: string; text: string }>
output?: string | Array<{ type: string; text: string }>
}
if (messageBlock.type === "text") {
return {
text: messageBlock.text || ''
} as ContentBlock
}
if (messageBlock.type === "image" && messageBlock.source) {
// Convert base64 string to byte array if needed
let byteArray: Uint8Array
if (typeof messageBlock.source.data === 'string') {
const binaryString = atob(messageBlock.source.data)
byteArray = new Uint8Array(binaryString.length)
for (let i = 0; i < binaryString.length; i++) {
byteArray[i] = binaryString.charCodeAt(i)
}
} else {
byteArray = messageBlock.source.data
}
if (messageBlock.type === "text") {
return {
text: messageBlock.text || "",
} as ContentBlock
}
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
const format = messageBlock.source.media_type.split('/')[1]
if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) {
throw new Error(`Unsupported image format: ${format}`)
}
if (messageBlock.type === "image" && messageBlock.source) {
// Convert base64 string to byte array if needed
let byteArray: Uint8Array
if (typeof messageBlock.source.data === "string") {
const binaryString = atob(messageBlock.source.data)
byteArray = new Uint8Array(binaryString.length)
for (let i = 0; i < binaryString.length; i++) {
byteArray[i] = binaryString.charCodeAt(i)
}
} else {
byteArray = messageBlock.source.data
}
return {
image: {
format: format as "png" | "jpeg" | "gif" | "webp",
source: {
bytes: byteArray
}
}
} as ContentBlock
}
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
const format = messageBlock.source.media_type.split("/")[1]
if (!["png", "jpeg", "gif", "webp"].includes(format)) {
throw new Error(`Unsupported image format: ${format}`)
}
if (messageBlock.type === "tool_use") {
// Convert tool use to XML format
const toolParams = Object.entries(messageBlock.input || {})
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
.join('\n')
return {
image: {
format: format as "png" | "jpeg" | "gif" | "webp",
source: {
bytes: byteArray,
},
},
} as ContentBlock
}
return {
toolUse: {
toolUseId: messageBlock.id || '',
name: messageBlock.name || '',
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`
}
} as ContentBlock
}
if (messageBlock.type === "tool_use") {
// Convert tool use to XML format
const toolParams = Object.entries(messageBlock.input || {})
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
.join("\n")
if (messageBlock.type === "tool_result") {
// First try to use content if available
if (messageBlock.content && Array.isArray(messageBlock.content)) {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || '',
content: messageBlock.content.map(item => ({
text: item.text
})),
status: "success"
}
} as ContentBlock
}
return {
toolUse: {
toolUseId: messageBlock.id || "",
name: messageBlock.name || "",
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`,
},
} as ContentBlock
}
// Fall back to output handling if content is not available
if (messageBlock.output && typeof messageBlock.output === "string") {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || '',
content: [{
text: messageBlock.output
}],
status: "success"
}
} as ContentBlock
}
// Handle array of content blocks if output is an array
if (Array.isArray(messageBlock.output)) {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || '',
content: messageBlock.output.map(part => {
if (typeof part === "object" && "text" in part) {
return { text: part.text }
}
// Skip images in tool results as they're handled separately
if (typeof part === "object" && "type" in part && part.type === "image") {
return { text: "(see following message for image)" }
}
return { text: String(part) }
}),
status: "success"
}
} as ContentBlock
}
if (messageBlock.type === "tool_result") {
// First try to use content if available
if (messageBlock.content && Array.isArray(messageBlock.content)) {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: messageBlock.content.map((item) => ({
text: item.text,
})),
status: "success",
},
} as ContentBlock
}
// Default case
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || '',
content: [{
text: String(messageBlock.output || '')
}],
status: "success"
}
} as ContentBlock
}
// Fall back to output handling if content is not available
if (messageBlock.output && typeof messageBlock.output === "string") {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: [
{
text: messageBlock.output,
},
],
status: "success",
},
} as ContentBlock
}
// Handle array of content blocks if output is an array
if (Array.isArray(messageBlock.output)) {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: messageBlock.output.map((part) => {
if (typeof part === "object" && "text" in part) {
return { text: part.text }
}
// Skip images in tool results as they're handled separately
if (typeof part === "object" && "type" in part && part.type === "image") {
return { text: "(see following message for image)" }
}
return { text: String(part) }
}),
status: "success",
},
} as ContentBlock
}
if (messageBlock.type === "video") {
const videoContent = messageBlock.s3Location ? {
s3Location: {
uri: messageBlock.s3Location.uri,
bucketOwner: messageBlock.s3Location.bucketOwner
}
} : messageBlock.source
// Default case
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: [
{
text: String(messageBlock.output || ""),
},
],
status: "success",
},
} as ContentBlock
}
return {
video: {
format: "mp4", // Default to mp4, adjust based on actual format if needed
source: videoContent
}
} as ContentBlock
}
if (messageBlock.type === "video") {
const videoContent = messageBlock.s3Location
? {
s3Location: {
uri: messageBlock.s3Location.uri,
bucketOwner: messageBlock.s3Location.bucketOwner,
},
}
: messageBlock.source
// Default case for unknown block types
return {
text: '[Unknown Block Type]'
} as ContentBlock
})
return {
video: {
format: "mp4", // Default to mp4, adjust based on actual format if needed
source: videoContent,
},
} as ContentBlock
}
return {
role,
content
}
})
// Default case for unknown block types
return {
text: "[Unknown Block Type]",
} as ContentBlock
})
return {
role,
content,
}
})
}
/**
* Convert Bedrock Converse stream events to Anthropic message format
*/
export function convertToAnthropicMessage(
streamEvent: StreamEvent,
modelId: string
streamEvent: StreamEvent,
modelId: string,
): Partial<Anthropic.Messages.Message> {
// Handle metadata events
if (streamEvent.metadata?.usage) {
return {
id: '', // Bedrock doesn't provide message IDs
type: "message",
role: "assistant",
model: modelId,
usage: {
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
output_tokens: streamEvent.metadata.usage.outputTokens || 0
}
}
}
// Handle metadata events
if (streamEvent.metadata?.usage) {
return {
id: "", // Bedrock doesn't provide message IDs
type: "message",
role: "assistant",
model: modelId,
usage: {
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
output_tokens: streamEvent.metadata.usage.outputTokens || 0,
},
}
}
// Handle content blocks
const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text
if (text !== undefined) {
return {
type: "message",
role: "assistant",
content: [{ type: "text", text: text }],
model: modelId
}
}
// Handle content blocks
const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text
if (text !== undefined) {
return {
type: "message",
role: "assistant",
content: [{ type: "text", text: text }],
model: modelId,
}
}
// Handle message stop
if (streamEvent.messageStop) {
return {
type: "message",
role: "assistant",
stop_reason: streamEvent.messageStop.stopReason || null,
stop_sequence: null,
model: modelId
}
}
// Handle message stop
if (streamEvent.messageStop) {
return {
type: "message",
role: "assistant",
stop_reason: streamEvent.messageStop.stopReason || null,
stop_sequence: null,
model: modelId,
}
}
return {}
return {}
}

View File

@@ -1,5 +1,5 @@
import { Anthropic } from "@anthropic-ai/sdk";
import * as vscode from 'vscode';
import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode"
/**
* Safely converts a value into a plain object.
@@ -7,30 +7,31 @@ import * as vscode from 'vscode';
function asObjectSafe(value: any): object {
// Handle null/undefined
if (!value) {
return {};
return {}
}
try {
// Handle strings that might be JSON
if (typeof value === 'string') {
return JSON.parse(value);
if (typeof value === "string") {
return JSON.parse(value)
}
// Handle pre-existing objects
if (typeof value === 'object') {
return Object.assign({}, value);
if (typeof value === "object") {
return Object.assign({}, value)
}
return {};
}
catch (error) {
console.warn('Cline <Language Model API>: Failed to parse object:', error);
return {};
return {}
} catch (error) {
console.warn("Cline <Language Model API>: Failed to parse object:", error)
return {}
}
}
export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): vscode.LanguageModelChatMessage[] {
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [];
export function convertToVsCodeLmMessages(
anthropicMessages: Anthropic.Messages.MessageParam[],
): vscode.LanguageModelChatMessage[] {
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []
for (const anthropicMessage of anthropicMessages) {
// Handle simple string messages
@@ -38,135 +39,129 @@ export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.
vsCodeLmMessages.push(
anthropicMessage.role === "assistant"
? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content)
: vscode.LanguageModelChatMessage.User(anthropicMessage.content)
);
continue;
: vscode.LanguageModelChatMessage.User(anthropicMessage.content),
)
continue
}
// Handle complex message structures
switch (anthropicMessage.role) {
case "user": {
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
toolMessages: Anthropic.ToolResultBlockParam[];
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
toolMessages: Anthropic.ToolResultBlockParam[]
}>(
(acc, part) => {
if (part.type === "tool_result") {
acc.toolMessages.push(part);
acc.toolMessages.push(part)
} else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part)
}
else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part);
}
return acc;
return acc
},
{ nonToolMessages: [], toolMessages: [] },
);
)
// Process tool messages first then non-tool messages
const contentParts = [
// Convert tool messages to ToolResultParts
...toolMessages.map((toolMessage) => {
// Process tool result content into TextParts
const toolContentParts: vscode.LanguageModelTextPart[] = (
const toolContentParts: vscode.LanguageModelTextPart[] =
typeof toolMessage.content === "string"
? [new vscode.LanguageModelTextPart(toolMessage.content)]
: (
toolMessage.content?.map((part) => {
: (toolMessage.content?.map((part) => {
if (part.type === "image") {
return new vscode.LanguageModelTextPart(
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
);
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
)
}
return new vscode.LanguageModelTextPart(part.text);
})
?? [new vscode.LanguageModelTextPart("")]
)
);
return new vscode.LanguageModelTextPart(part.text)
}) ?? [new vscode.LanguageModelTextPart("")])
return new vscode.LanguageModelToolResultPart(
toolMessage.tool_use_id,
toolContentParts
);
return new vscode.LanguageModelToolResultPart(toolMessage.tool_use_id, toolContentParts)
}),
// Convert non-tool messages to TextParts after tool messages
...nonToolMessages.map((part) => {
if (part.type === "image") {
return new vscode.LanguageModelTextPart(
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
);
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
)
}
return new vscode.LanguageModelTextPart(part.text);
})
];
return new vscode.LanguageModelTextPart(part.text)
}),
]
// Add single user message with all content parts
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts));
break;
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts))
break
}
case "assistant": {
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
toolMessages: Anthropic.ToolUseBlockParam[];
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
toolMessages: Anthropic.ToolUseBlockParam[]
}>(
(acc, part) => {
if (part.type === "tool_use") {
acc.toolMessages.push(part);
acc.toolMessages.push(part)
} else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part)
}
else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part);
}
return acc;
return acc
},
{ nonToolMessages: [], toolMessages: [] },
);
)
// Process tool messages first then non-tool messages
// Process tool messages first then non-tool messages
const contentParts = [
// Convert tool messages to ToolCallParts first
...toolMessages.map((toolMessage) =>
new vscode.LanguageModelToolCallPart(
toolMessage.id,
toolMessage.name,
asObjectSafe(toolMessage.input)
)
...toolMessages.map(
(toolMessage) =>
new vscode.LanguageModelToolCallPart(
toolMessage.id,
toolMessage.name,
asObjectSafe(toolMessage.input),
),
),
// Convert non-tool messages to TextParts after tool messages
...nonToolMessages.map((part) => {
if (part.type === "image") {
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]");
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]")
}
return new vscode.LanguageModelTextPart(part.text);
})
];
return new vscode.LanguageModelTextPart(part.text)
}),
]
// Add the assistant message to the list of messages
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts));
break;
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts))
break
}
}
}
return vsCodeLmMessages;
return vsCodeLmMessages
}
export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null {
switch (vsCodeLmMessageRole) {
case vscode.LanguageModelChatMessageRole.Assistant:
return "assistant";
return "assistant"
case vscode.LanguageModelChatMessageRole.User:
return "user";
return "user"
default:
return null;
return null
}
}
export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.LanguageModelChatMessage): Promise<Anthropic.Messages.Message> {
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role);
export async function convertToAnthropicMessage(
vsCodeLmMessage: vscode.LanguageModelChatMessage,
): Promise<Anthropic.Messages.Message> {
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role)
if (anthropicRole !== "assistant") {
throw new Error("Cline <Language Model API>: Only assistant messages are supported.");
throw new Error("Cline <Language Model API>: Only assistant messages are supported.")
}
return {
@@ -174,36 +169,32 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
type: "message",
model: "vscode-lm",
role: anthropicRole,
content: (
vsCodeLmMessage.content
.map((part): Anthropic.ContentBlock | null => {
if (part instanceof vscode.LanguageModelTextPart) {
return {
type: "text",
text: part.value
};
content: vsCodeLmMessage.content
.map((part): Anthropic.ContentBlock | null => {
if (part instanceof vscode.LanguageModelTextPart) {
return {
type: "text",
text: part.value,
}
}
if (part instanceof vscode.LanguageModelToolCallPart) {
return {
type: "tool_use",
id: part.callId || crypto.randomUUID(),
name: part.name,
input: asObjectSafe(part.input)
};
if (part instanceof vscode.LanguageModelToolCallPart) {
return {
type: "tool_use",
id: part.callId || crypto.randomUUID(),
name: part.name,
input: asObjectSafe(part.input),
}
}
return null;
})
.filter(
(part): part is Anthropic.ContentBlock => part !== null
)
),
return null
})
.filter((part): part is Anthropic.ContentBlock => part !== null),
stop_reason: null,
stop_sequence: null,
usage: {
input_tokens: 0,
output_tokens: 0,
}
};
},
}
}