Mirror of https://github.com/pacnpal/Roo-Code.git (synced 2025-12-20 04:11:10 -05:00)
Prettier backfill
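The reformatted lines in this diff (double quotes, dropped semicolons, trailing commas in multi-line literals and argument lists, long expect(...) calls kept on one line) are consistent with running Prettier using roughly the options sketched below. This is an illustrative guess only: the repository's actual Prettier configuration is not shown on this page, and every value here is inferred from the formatting visible in the diff.

// prettier.config.js (hypothetical sketch, not taken from the repository)
module.exports = {
	semi: false, // reformatted lines drop trailing semicolons
	printWidth: 120, // inferred: ~110-character expect(...).rejects.toThrow(...) calls stay on one line
	trailingComma: "all", // reformatted lines add trailing commas in multi-line literals and call arguments
	// quote style is left at Prettier's default (double quotes), matching the reformatted lines
}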
@@ -1,239 +1,238 @@
import { AnthropicHandler } from '../anthropic';
import { ApiHandlerOptions } from '../../../shared/api';
import { ApiStream } from '../../transform/stream';
import { Anthropic } from '@anthropic-ai/sdk';
import { AnthropicHandler } from "../anthropic"
import { ApiHandlerOptions } from "../../../shared/api"
import { ApiStream } from "../../transform/stream"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock Anthropic client
|
||||
const mockBetaCreate = jest.fn();
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('@anthropic-ai/sdk', () => {
|
||||
return {
|
||||
Anthropic: jest.fn().mockImplementation(() => ({
|
||||
beta: {
|
||||
promptCaching: {
|
||||
messages: {
|
||||
create: mockBetaCreate.mockImplementation(async () => ({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
cache_creation_input_tokens: 20,
|
||||
cache_read_input_tokens: 10
|
||||
}
|
||||
}
|
||||
};
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
}
|
||||
};
|
||||
yield {
|
||||
type: 'content_block_delta',
|
||||
delta: {
|
||||
type: 'text_delta',
|
||||
text: ' world'
|
||||
}
|
||||
};
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
},
|
||||
messages: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
content: [
|
||||
{ type: 'text', text: 'Test response' }
|
||||
],
|
||||
role: 'assistant',
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
const mockBetaCreate = jest.fn()
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("@anthropic-ai/sdk", () => {
|
||||
return {
|
||||
Anthropic: jest.fn().mockImplementation(() => ({
|
||||
beta: {
|
||||
promptCaching: {
|
||||
messages: {
|
||||
create: mockBetaCreate.mockImplementation(async () => ({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
cache_creation_input_tokens: 20,
|
||||
cache_read_input_tokens: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
yield {
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
},
|
||||
}
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
delta: {
|
||||
type: "text_delta",
|
||||
text: " world",
|
||||
},
|
||||
}
|
||||
},
|
||||
})),
|
||||
},
|
||||
},
|
||||
},
|
||||
messages: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: "test-completion",
|
||||
content: [{ type: "text", text: "Test response" }],
|
||||
role: "assistant",
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
yield {
|
||||
type: "content_block_start",
|
||||
content_block: {
|
||||
type: "text",
|
||||
text: "Test response",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
})),
|
||||
}
|
||||
})

describe('AnthropicHandler', () => {
let handler: AnthropicHandler;
let mockOptions: ApiHandlerOptions;
describe("AnthropicHandler", () => {
let handler: AnthropicHandler
let mockOptions: ApiHandlerOptions

beforeEach(() => {
mockOptions = {
apiKey: 'test-api-key',
apiModelId: 'claude-3-5-sonnet-20241022'
};
handler = new AnthropicHandler(mockOptions);
mockBetaCreate.mockClear();
mockCreate.mockClear();
});
beforeEach(() => {
mockOptions = {
apiKey: "test-api-key",
apiModelId: "claude-3-5-sonnet-20241022",
}
handler = new AnthropicHandler(mockOptions)
mockBetaCreate.mockClear()
mockCreate.mockClear()
})

describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(AnthropicHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(AnthropicHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||
})
|
||||
|
||||
it('should initialize with undefined API key', () => {
|
||||
// The SDK will handle API key validation, so we just verify it initializes
|
||||
const handlerWithoutKey = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
apiKey: undefined
|
||||
});
|
||||
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler);
|
||||
});
|
||||
it("should initialize with undefined API key", () => {
|
||||
// The SDK will handle API key validation, so we just verify it initializes
|
||||
const handlerWithoutKey = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
apiKey: undefined,
|
||||
})
|
||||
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler)
|
||||
})
|
||||
|
||||
it('should use custom base URL if provided', () => {
|
||||
const customBaseUrl = 'https://custom.anthropic.com';
|
||||
const handlerWithCustomUrl = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
anthropicBaseUrl: customBaseUrl
|
||||
});
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler);
|
||||
});
|
||||
});
|
||||
it("should use custom base URL if provided", () => {
|
||||
const customBaseUrl = "https://custom.anthropic.com"
|
||||
const handlerWithCustomUrl = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
anthropicBaseUrl: customBaseUrl,
|
||||
})
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello!'
|
||||
}]
|
||||
}
|
||||
];
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello!",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle prompt caching for supported models', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text' as const, text: 'First message' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text' as const, text: 'Response' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text' as const, text: 'Second message' }]
|
||||
}
|
||||
]);
|
||||
it("should handle prompt caching for supported models", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, [
|
||||
{
|
||||
role: "user",
|
||||
content: [{ type: "text" as const, text: "First message" }],
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: [{ type: "text" as const, text: "Response" }],
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: [{ type: "text" as const, text: "Second message" }],
|
||||
},
|
||||
])
|
||||
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
// Verify usage information
|
||||
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
|
||||
expect(usageChunk).toBeDefined();
|
||||
expect(usageChunk?.inputTokens).toBe(100);
|
||||
expect(usageChunk?.outputTokens).toBe(50);
|
||||
expect(usageChunk?.cacheWriteTokens).toBe(20);
|
||||
expect(usageChunk?.cacheReadTokens).toBe(10);
|
||||
// Verify usage information
|
||||
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
|
||||
expect(usageChunk).toBeDefined()
|
||||
expect(usageChunk?.inputTokens).toBe(100)
|
||||
expect(usageChunk?.outputTokens).toBe(50)
|
||||
expect(usageChunk?.cacheWriteTokens).toBe(20)
|
||||
expect(usageChunk?.cacheReadTokens).toBe(10)
|
||||
|
||||
// Verify text content
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(2);
|
||||
expect(textChunks[0].text).toBe('Hello');
|
||||
expect(textChunks[1].text).toBe(' world');
|
||||
// Verify text content
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(2)
|
||||
expect(textChunks[0].text).toBe("Hello")
|
||||
expect(textChunks[1].text).toBe(" world")
|
||||
|
||||
// Verify beta API was used
|
||||
expect(mockBetaCreate).toHaveBeenCalled();
|
||||
expect(mockCreate).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
// Verify beta API was used
|
||||
expect(mockBetaCreate).toHaveBeenCalled()
|
||||
expect(mockCreate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.apiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.apiModelId,
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Anthropic completion error: API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle non-text content', async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: 'image' }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
it("should handle non-text content", async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: "image" }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: 'text', text: '' }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: "text", text: "" }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return default model if no model ID is provided', () => {
|
||||
const handlerWithoutModel = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
apiModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBeDefined();
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return default model if no model ID is provided", () => {
|
||||
const handlerWithoutModel = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
apiModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBeDefined()
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
|
||||
it('should return specified model if valid model ID is provided', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.apiModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.maxTokens).toBe(8192);
|
||||
expect(model.info.contextWindow).toBe(200_000);
|
||||
expect(model.info.supportsImages).toBe(true);
|
||||
expect(model.info.supportsPromptCache).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should return specified model if valid model ID is provided", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.apiModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.maxTokens).toBe(8192)
|
||||
expect(model.info.contextWindow).toBe(200_000)
|
||||
expect(model.info.supportsImages).toBe(true)
|
||||
expect(model.info.supportsPromptCache).toBe(true)
|
||||
})
|
||||
})
|
||||
})
@@ -1,246 +1,259 @@
import { AwsBedrockHandler } from '../bedrock';
import { MessageContent } from '../../../shared/api';
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
import { Anthropic } from '@anthropic-ai/sdk';
import { AwsBedrockHandler } from "../bedrock"
import { MessageContent } from "../../../shared/api"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"

describe('AwsBedrockHandler', () => {
|
||||
let handler: AwsBedrockHandler;
|
||||
describe("AwsBedrockHandler", () => {
|
||||
let handler: AwsBedrockHandler
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1'
|
||||
});
|
||||
});
|
||||
beforeEach(() => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
awsAccessKey: "test-access-key",
|
||||
awsSecretKey: "test-secret-key",
|
||||
awsRegion: "us-east-1",
|
||||
})
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(handler['options'].awsAccessKey).toBe('test-access-key');
|
||||
expect(handler['options'].awsSecretKey).toBe('test-secret-key');
|
||||
expect(handler['options'].awsRegion).toBe('us-east-1');
|
||||
expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided config", () => {
|
||||
expect(handler["options"].awsAccessKey).toBe("test-access-key")
|
||||
expect(handler["options"].awsSecretKey).toBe("test-secret-key")
|
||||
expect(handler["options"].awsRegion).toBe("us-east-1")
|
||||
expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
})
|
||||
|
||||
it('should initialize with missing AWS credentials', () => {
|
||||
const handlerWithoutCreds = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsRegion: 'us-east-1'
|
||||
});
|
||||
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler);
|
||||
});
|
||||
});
|
||||
it("should initialize with missing AWS credentials", () => {
|
||||
const handlerWithoutCreds = new AwsBedrockHandler({
|
||||
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
awsRegion: "us-east-1",
|
||||
})
|
||||
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
describe("createMessage", () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
|
||||
it('should handle text messages correctly', async () => {
|
||||
const mockResponse = {
|
||||
messages: [{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Hello! How can I help you?' }]
|
||||
}],
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
};
|
||||
it("should handle text messages correctly", async () => {
|
||||
const mockResponse = {
|
||||
messages: [
|
||||
{
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello! How can I help you?" }],
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
},
|
||||
}
|
||||
|
||||
// Mock AWS SDK invoke
|
||||
const mockStream = {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 5
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
// Mock AWS SDK invoke
|
||||
const mockStream = {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockInvoke = jest.fn().mockResolvedValue({
|
||||
stream: mockStream
|
||||
});
|
||||
const mockInvoke = jest.fn().mockResolvedValue({
|
||||
stream: mockStream,
|
||||
})
|
||||
|
||||
handler['client'] = {
|
||||
send: mockInvoke
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
handler["client"] = {
|
||||
send: mockInvoke,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'usage',
|
||||
inputTokens: 10,
|
||||
outputTokens: 5
|
||||
});
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0'
|
||||
})
|
||||
}));
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 5,
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
// Mock AWS SDK invoke with error
|
||||
const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error'));
|
||||
expect(mockInvoke).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
handler['client'] = {
|
||||
send: mockInvoke
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
it("should handle API errors", async () => {
|
||||
// Mock AWS SDK invoke with error
|
||||
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
handler["client"] = {
|
||||
send: mockInvoke,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('AWS Bedrock error');
|
||||
});
|
||||
});
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({
|
||||
content: 'Test response'
|
||||
}))
|
||||
};
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow("AWS Bedrock error")
|
||||
})
|
||||
})
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(
|
||||
JSON.stringify({
|
||||
content: "Test response",
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'user',
|
||||
content: [{ text: 'Test prompt' }]
|
||||
})
|
||||
]),
|
||||
inferenceConfig: expect.objectContaining({
|
||||
maxTokens: 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
})
|
||||
})
|
||||
}));
|
||||
});
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('AWS Bedrock error');
|
||||
const mockSend = jest.fn().mockRejectedValue(mockError);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockSend).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: "user",
|
||||
content: [{ text: "Test prompt" }],
|
||||
}),
|
||||
]),
|
||||
inferenceConfig: expect.objectContaining({
|
||||
maxTokens: 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("AWS Bedrock error")
|
||||
const mockSend = jest.fn().mockRejectedValue(mockError)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
it('should handle invalid response format', async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode('invalid json')
|
||||
};
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"Bedrock completion error: AWS Bedrock error",
|
||||
)
|
||||
})
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
it("should handle invalid response format", async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode("invalid json"),
|
||||
}
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({}))
|
||||
};
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
it("should handle empty response", async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({})),
|
||||
}
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
it('should handle cross-region inference', async () => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1',
|
||||
awsUseCrossRegionInference: true
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({
|
||||
content: 'Test response'
|
||||
}))
|
||||
};
|
||||
it("should handle cross-region inference", async () => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
awsAccessKey: "test-access-key",
|
||||
awsSecretKey: "test-secret-key",
|
||||
awsRegion: "us-east-1",
|
||||
awsUseCrossRegionInference: true,
|
||||
})
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(
|
||||
JSON.stringify({
|
||||
content: "Test response",
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
|
||||
})
|
||||
}));
|
||||
});
|
||||
});
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info in test environment', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockSend).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should return test model info for invalid model in test environment', () => {
|
||||
const invalidHandler = new AwsBedrockHandler({
|
||||
apiModelId: 'invalid-model',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1'
|
||||
});
|
||||
const modelInfo = invalidHandler.getModel();
|
||||
expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed
|
||||
expect(modelInfo.info.maxTokens).toBe(5000);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
});
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return correct model info in test environment", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
|
||||
})
|
||||
|
||||
it("should return test model info for invalid model in test environment", () => {
|
||||
const invalidHandler = new AwsBedrockHandler({
|
||||
apiModelId: "invalid-model",
|
||||
awsAccessKey: "test-access-key",
|
||||
awsSecretKey: "test-secret-key",
|
||||
awsRegion: "us-east-1",
|
||||
})
|
||||
const modelInfo = invalidHandler.getModel()
|
||||
expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
|
||||
expect(modelInfo.info.maxTokens).toBe(5000)
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||
})
|
||||
})
|
||||
})
@@ -1,203 +1,217 @@
import { DeepSeekHandler } from '../deepseek';
import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api';
import OpenAI from 'openai';
import { Anthropic } from '@anthropic-ai/sdk';
import { DeepSeekHandler } from "../deepseek"
import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response', refusal: null },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Return async iterator for streaming
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response", refusal: null },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
describe('DeepSeekHandler', () => {
|
||||
let handler: DeepSeekHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
// Return async iterator for streaming
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
deepSeekApiKey: 'test-api-key',
|
||||
deepSeekModelId: 'deepseek-chat',
|
||||
deepSeekBaseUrl: 'https://api.deepseek.com/v1'
|
||||
};
|
||||
handler = new DeepSeekHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
describe("DeepSeekHandler", () => {
|
||||
let handler: DeepSeekHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(DeepSeekHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId);
|
||||
});
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
deepSeekApiKey: "test-api-key",
|
||||
deepSeekModelId: "deepseek-chat",
|
||||
deepSeekBaseUrl: "https://api.deepseek.com/v1",
|
||||
}
|
||||
handler = new DeepSeekHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
it('should throw error if API key is missing', () => {
|
||||
expect(() => {
|
||||
new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekApiKey: undefined
|
||||
});
|
||||
}).toThrow('DeepSeek API key is required');
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(DeepSeekHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId)
|
||||
})
|
||||
|
||||
it('should use default model ID if not provided', () => {
|
||||
const handlerWithoutModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: undefined
|
||||
});
|
||||
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId);
|
||||
});
|
||||
it("should throw error if API key is missing", () => {
|
||||
expect(() => {
|
||||
new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekApiKey: undefined,
|
||||
})
|
||||
}).toThrow("DeepSeek API key is required")
|
||||
})
|
||||
|
||||
it('should use default base URL if not provided', () => {
|
||||
const handlerWithoutBaseUrl = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekBaseUrl: undefined
|
||||
});
|
||||
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler);
|
||||
// The base URL is passed to OpenAI client internally
|
||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
||||
baseURL: 'https://api.deepseek.com/v1'
|
||||
}));
|
||||
});
|
||||
it("should use default model ID if not provided", () => {
|
||||
const handlerWithoutModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: undefined,
|
||||
})
|
||||
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId)
|
||||
})
|
||||
|
||||
it('should use custom base URL if provided', () => {
|
||||
const customBaseUrl = 'https://custom.deepseek.com/v1';
|
||||
const handlerWithCustomUrl = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekBaseUrl: customBaseUrl
|
||||
});
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler);
|
||||
// The custom base URL is passed to OpenAI client
|
||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
||||
baseURL: customBaseUrl
|
||||
}));
|
||||
});
|
||||
it("should use default base URL if not provided", () => {
|
||||
const handlerWithoutBaseUrl = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekBaseUrl: undefined,
|
||||
})
|
||||
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler)
|
||||
// The base URL is passed to OpenAI client internally
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
baseURL: "https://api.deepseek.com/v1",
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should set includeMaxTokens to true', () => {
|
||||
// Create a new handler and verify OpenAI client was called with includeMaxTokens
|
||||
new DeepSeekHandler(mockOptions);
|
||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
||||
apiKey: mockOptions.deepSeekApiKey
|
||||
}));
|
||||
});
|
||||
});
|
||||
it("should use custom base URL if provided", () => {
|
||||
const customBaseUrl = "https://custom.deepseek.com/v1"
|
||||
const handlerWithCustomUrl = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekBaseUrl: customBaseUrl,
|
||||
})
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler)
|
||||
// The custom base URL is passed to OpenAI client
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
baseURL: customBaseUrl,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info for valid model ID', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.deepSeekModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.maxTokens).toBe(8192);
|
||||
expect(model.info.contextWindow).toBe(64_000);
|
||||
expect(model.info.supportsImages).toBe(false);
|
||||
expect(model.info.supportsPromptCache).toBe(false);
|
||||
});
|
||||
it("should set includeMaxTokens to true", () => {
|
||||
// Create a new handler and verify OpenAI client was called with includeMaxTokens
|
||||
new DeepSeekHandler(mockOptions)
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
apiKey: mockOptions.deepSeekApiKey,
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should return provided model ID with default model info if model does not exist', () => {
|
||||
const handlerWithInvalidModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: 'invalid-model'
|
||||
});
|
||||
const model = handlerWithInvalidModel.getModel();
|
||||
expect(model.id).toBe('invalid-model'); // Returns provided ID
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info).toBe(handler.getModel().info); // But uses default model info
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info for valid model ID", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.deepSeekModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.maxTokens).toBe(8192)
|
||||
expect(model.info.contextWindow).toBe(64_000)
|
||||
expect(model.info.supportsImages).toBe(false)
|
||||
expect(model.info.supportsPromptCache).toBe(false)
|
||||
})
|
||||
|
||||
it('should return default model if no model ID is provided', () => {
|
||||
const handlerWithoutModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBe(deepSeekDefaultModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
it("should return provided model ID with default model info if model does not exist", () => {
|
||||
const handlerWithInvalidModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: "invalid-model",
|
||||
})
|
||||
const model = handlerWithInvalidModel.getModel()
|
||||
expect(model.id).toBe("invalid-model") // Returns provided ID
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info).toBe(handler.getModel().info) // But uses default model info
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello!'
|
||||
}]
|
||||
}
|
||||
];
|
||||
it("should return default model if no model ID is provided", () => {
|
||||
const handlerWithoutModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBe(deepSeekDefaultModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello!",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
it('should include usage information', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
|
||||
const usageChunks = chunks.filter(chunk => chunk.type === 'usage');
|
||||
expect(usageChunks.length).toBeGreaterThan(0);
|
||||
expect(usageChunks[0].inputTokens).toBe(10);
|
||||
expect(usageChunks[0].outputTokens).toBe(5);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should include usage information", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
|
||||
expect(usageChunks.length).toBeGreaterThan(0)
|
||||
expect(usageChunks[0].inputTokens).toBe(10)
|
||||
expect(usageChunks[0].outputTokens).toBe(5)
|
||||
})
|
||||
})
|
||||
})
@@ -1,212 +1,210 @@
import { GeminiHandler } from '../gemini';
import { Anthropic } from '@anthropic-ai/sdk';
import { GoogleGenerativeAI } from '@google/generative-ai';
import { GeminiHandler } from "../gemini"
import { Anthropic } from "@anthropic-ai/sdk"
import { GoogleGenerativeAI } from "@google/generative-ai"

// Mock the Google Generative AI SDK
|
||||
jest.mock('@google/generative-ai', () => ({
|
||||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
||||
getGenerativeModel: jest.fn().mockReturnValue({
|
||||
generateContentStream: jest.fn(),
|
||||
generateContent: jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => 'Test response'
|
||||
}
|
||||
})
|
||||
})
|
||||
}))
|
||||
}));
|
||||
jest.mock("@google/generative-ai", () => ({
|
||||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
||||
getGenerativeModel: jest.fn().mockReturnValue({
|
||||
generateContentStream: jest.fn(),
|
||||
generateContent: jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => "Test response",
|
||||
},
|
||||
}),
|
||||
}),
|
||||
})),
|
||||
}))

describe('GeminiHandler', () => {
let handler: GeminiHandler;
describe("GeminiHandler", () => {
let handler: GeminiHandler

beforeEach(() => {
handler = new GeminiHandler({
apiKey: 'test-key',
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
geminiApiKey: 'test-key'
});
});
beforeEach(() => {
handler = new GeminiHandler({
apiKey: "test-key",
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
geminiApiKey: "test-key",
})
})

describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(handler['options'].geminiApiKey).toBe('test-key');
|
||||
expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219');
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided config", () => {
|
||||
expect(handler["options"].geminiApiKey).toBe("test-key")
|
||||
expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
|
||||
})
|
||||
|
||||
it('should throw if API key is missing', () => {
|
||||
expect(() => {
|
||||
new GeminiHandler({
|
||||
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
|
||||
geminiApiKey: ''
|
||||
});
|
||||
}).toThrow('API key is required for Google Gemini');
|
||||
});
|
||||
});
|
||||
it("should throw if API key is missing", () => {
|
||||
expect(() => {
|
||||
new GeminiHandler({
|
||||
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
||||
geminiApiKey: "",
|
||||
})
|
||||
}).toThrow("API key is required for Google Gemini")
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
describe("createMessage", () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
|
||||
it('should handle text messages correctly', async () => {
|
||||
// Mock the stream response
|
||||
const mockStream = {
|
||||
stream: [
|
||||
{ text: () => 'Hello' },
|
||||
{ text: () => ' world!' }
|
||||
],
|
||||
response: {
|
||||
usageMetadata: {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 5
|
||||
}
|
||||
}
|
||||
};
|
||||
it("should handle text messages correctly", async () => {
|
||||
// Mock the stream response
|
||||
const mockStream = {
|
||||
stream: [{ text: () => "Hello" }, { text: () => " world!" }],
|
||||
response: {
|
||||
usageMetadata: {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Setup the mock implementation
|
||||
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream);
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContentStream: mockGenerateContentStream
|
||||
});
|
||||
// Setup the mock implementation
|
||||
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
})
|
||||
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
// Should have 3 chunks: 'Hello', ' world!', and usage info
|
||||
expect(chunks.length).toBe(3);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: ' world!'
|
||||
});
|
||||
expect(chunks[2]).toEqual({
|
||||
type: 'usage',
|
||||
inputTokens: 10,
|
||||
outputTokens: 5
|
||||
});
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
// Verify the model configuration
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||
model: 'gemini-2.0-flash-thinking-exp-1219',
|
||||
systemInstruction: systemPrompt
|
||||
});
|
||||
// Should have 3 chunks: 'Hello', ' world!', and usage info
|
||||
expect(chunks.length).toBe(3)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: "text",
|
||||
text: " world!",
|
||||
})
|
||||
expect(chunks[2]).toEqual({
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 5,
|
||||
})
|
||||
|
||||
// Verify generation config
|
||||
expect(mockGenerateContentStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
generationConfig: {
|
||||
temperature: 0
|
||||
}
|
||||
})
|
||||
);
|
||||
});
|
||||
// Verify the model configuration
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||
model: "gemini-2.0-flash-thinking-exp-1219",
|
||||
systemInstruction: systemPrompt,
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Gemini API error');
|
||||
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError);
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContentStream: mockGenerateContentStream
|
||||
});
|
||||
// Verify generation config
|
||||
expect(mockGenerateContentStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
generationConfig: {
|
||||
temperature: 0,
|
||||
},
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Gemini API error")
|
||||
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('Gemini API error');
|
||||
});
|
||||
});
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => 'Test response'
|
||||
}
|
||||
});
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow("Gemini API error")
|
||||
})
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||
model: 'gemini-2.0-flash-thinking-exp-1219'
|
||||
});
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith({
|
||||
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
|
||||
generationConfig: {
|
||||
temperature: 0
|
||||
}
|
||||
});
|
||||
});
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => "Test response",
|
||||
},
|
||||
})
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
})
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Gemini API error');
|
||||
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
				model: "gemini-2.0-flash-thinking-exp-1219",
			})
			expect(mockGenerateContent).toHaveBeenCalledWith({
				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
				generationConfig: {
					temperature: 0,
				},
			})
		})

		it("should handle API errors", async () => {
			const mockError = new Error("Gemini API error")
			const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
			const mockGetGenerativeModel = jest.fn().mockReturnValue({
				generateContent: mockGenerateContent,
			})
			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
				"Gemini completion error: Gemini API error",
			)
		})

		it("should handle empty response", async () => {
			const mockGenerateContent = jest.fn().mockResolvedValue({
				response: {
					text: () => "",
				},
			})
			const mockGetGenerativeModel = jest.fn().mockReturnValue({
				generateContent: mockGenerateContent,
			})
			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel

			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})
	})

	describe("getModel", () => {
		it("should return correct model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(8192)
			expect(modelInfo.info.contextWindow).toBe(32_767)
		})

		it("should return default model if invalid model specified", () => {
			const invalidHandler = new GeminiHandler({
				apiModelId: "invalid-model",
				geminiApiKey: "test-key",
			})
			const modelInfo = invalidHandler.getModel()
			expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model
		})
	})
})
@@ -1,226 +1,238 @@
import { GlamaHandler } from "../glama"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import axios from "axios"

// Mock OpenAI client
const mockCreate = jest.fn()
const mockWithResponse = jest.fn()

jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: (...args: any[]) => {
						const stream = {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}

						const result = mockCreate(...args)
						if (args[0].stream) {
							mockWithResponse.mockReturnValue(
								Promise.resolve({
									data: stream,
									response: {
										headers: {
											get: (name: string) =>
												name === "x-completion-request-id" ? "test-request-id" : null,
										},
									},
								}),
							)
							result.withResponse = mockWithResponse
						}
						return result
					},
				},
			},
		})),
	}
})

describe("GlamaHandler", () => {
	let handler: GlamaHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "anthropic/claude-3-5-sonnet",
			glamaModelId: "anthropic/claude-3-5-sonnet",
			glamaApiKey: "test-api-key",
		}
		handler = new GlamaHandler(mockOptions)
		mockCreate.mockClear()
		mockWithResponse.mockClear()

		// Default mock implementation for non-streaming responses
		mockCreate.mockResolvedValue({
			id: "test-completion",
			choices: [
				{
					message: { role: "assistant", content: "Test response" },
					finish_reason: "stop",
					index: 0,
				},
			],
			usage: {
				prompt_tokens: 10,
				completion_tokens: 5,
				total_tokens: 15,
			},
		})
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(GlamaHandler)
			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			// Mock axios for token usage request
			const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({
				data: {
					tokenUsage: {
						promptTokens: 10,
						completionTokens: 5,
						cacheCreationInputTokens: 0,
						cacheReadInputTokens: 0,
					},
					totalCostUsd: "0.00",
				},
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBe(2) // Text chunk and usage chunk
			expect(chunks[0]).toEqual({
				type: "text",
				text: "Test response",
			})
			expect(chunks[1]).toEqual({
				type: "usage",
				inputTokens: 10,
				outputTokens: 5,
				cacheWriteTokens: 0,
				cacheReadTokens: 0,
				totalCost: 0,
			})

			mockAxios.mockRestore()
		})

		it("should handle API errors", async () => {
			mockCreate.mockImplementationOnce(() => {
				throw new Error("API Error")
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const chunks = []

			try {
				for await (const chunk of stream) {
					chunks.push(chunk)
				}
				fail("Expected error to be thrown")
			} catch (error) {
				expect(error).toBeInstanceOf(Error)
				expect(error.message).toBe("API Error")
			}
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: mockOptions.apiModelId,
					messages: [{ role: "user", content: "Test prompt" }],
					temperature: 0,
					max_tokens: 8192,
				}),
			)
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error")
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})

		it("should not set max_tokens for non-Anthropic models", async () => {
			// Reset mock to clear any previous calls
			mockCreate.mockClear()

			const nonAnthropicOptions = {
				apiModelId: "openai/gpt-4",
				glamaModelId: "openai/gpt-4",
				glamaApiKey: "test-key",
				glamaModelInfo: {
					maxTokens: 4096,
					contextWindow: 8192,
					supportsImages: true,
					supportsPromptCache: false,
				},
			}
			const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions)

			await nonAnthropicHandler.completePrompt("Test prompt")
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "openai/gpt-4",
					messages: [{ role: "user", content: "Test prompt" }],
					temperature: 0,
				}),
			)
			expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.apiModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(8192)
			expect(modelInfo.info.contextWindow).toBe(200_000)
		})
	})
})
@@ -1,160 +1,167 @@
import { LmStudioHandler } from "../lmstudio"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate.mockImplementation(async (options) => {
						if (!options.stream) {
							return {
								id: "test-completion",
								choices: [
									{
										message: { role: "assistant", content: "Test response" },
										finish_reason: "stop",
										index: 0,
									},
								],
								usage: {
									prompt_tokens: 10,
									completion_tokens: 5,
									total_tokens: 15,
								},
							}
						}

						return {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}
					}),
				},
			},
		})),
	}
})

describe("LmStudioHandler", () => {
	let handler: LmStudioHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "local-model",
			lmStudioModelId: "local-model",
			lmStudioBaseUrl: "http://localhost:1234/v1",
		}
		handler = new LmStudioHandler(mockOptions)
		mockCreate.mockClear()
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(LmStudioHandler)
			expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId)
		})

		it("should use default base URL if not provided", () => {
			const handlerWithoutUrl = new LmStudioHandler({
				apiModelId: "local-model",
				lmStudioModelId: "local-model",
			})
			expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunks = chunks.filter((chunk) => chunk.type === "text")
			expect(textChunks).toHaveLength(1)
			expect(textChunks[0].text).toBe("Test response")
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))

			const stream = handler.createMessage(systemPrompt, messages)

			await expect(async () => {
				for await (const chunk of stream) {
					// Should not reach here
				}
			}).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong")
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: mockOptions.lmStudioModelId,
				messages: [{ role: "user", content: "Test prompt" }],
				temperature: 0,
				stream: false,
			})
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
				"Please check the LM Studio developer logs to debug what went wrong",
			)
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(-1)
			expect(modelInfo.info.contextWindow).toBe(128_000)
		})
	})
})
@@ -1,160 +1,165 @@
import { OllamaHandler } from "../ollama"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate.mockImplementation(async (options) => {
						if (!options.stream) {
							return {
								id: "test-completion",
								choices: [
									{
										message: { role: "assistant", content: "Test response" },
										finish_reason: "stop",
										index: 0,
									},
								],
								usage: {
									prompt_tokens: 10,
									completion_tokens: 5,
									total_tokens: 15,
								},
							}
						}

						return {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
									usage: null,
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
									usage: {
										prompt_tokens: 10,
										completion_tokens: 5,
										total_tokens: 15,
									},
								}
							},
						}
					}),
				},
			},
		})),
	}
})

describe("OllamaHandler", () => {
	let handler: OllamaHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "llama2",
			ollamaModelId: "llama2",
			ollamaBaseUrl: "http://localhost:11434/v1",
		}
		handler = new OllamaHandler(mockOptions)
		mockCreate.mockClear()
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(OllamaHandler)
			expect(handler.getModel().id).toBe(mockOptions.ollamaModelId)
		})

		it("should use default base URL if not provided", () => {
			const handlerWithoutUrl = new OllamaHandler({
				apiModelId: "llama2",
				ollamaModelId: "llama2",
			})
			expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBeGreaterThan(0)
			const textChunks = chunks.filter((chunk) => chunk.type === "text")
			expect(textChunks).toHaveLength(1)
			expect(textChunks[0].text).toBe("Test response")
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))

			const stream = handler.createMessage(systemPrompt, messages)

			await expect(async () => {
				for await (const chunk of stream) {
					// Should not reach here
				}
			}).rejects.toThrow("API Error")
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: mockOptions.ollamaModelId,
				messages: [{ role: "user", content: "Test prompt" }],
				temperature: 0,
				stream: false,
			})
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error")
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.ollamaModelId)
			expect(modelInfo.info).toBeDefined()
			expect(modelInfo.info.maxTokens).toBe(-1)
			expect(modelInfo.info.contextWindow).toBe(128_000)
		})
	})
})
@@ -1,319 +1,326 @@
||||
import { OpenAiNativeHandler } from '../openai-native';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { OpenAiNativeHandler } from "../openai-native"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response" },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
describe('OpenAiNativeHandler', () => {
|
||||
let handler: OpenAiNativeHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
};
|
||||
handler = new OpenAiNativeHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
describe("OpenAiNativeHandler", () => {
|
||||
let handler: OpenAiNativeHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello!",
|
||||
},
|
||||
]
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: "gpt-4o",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
}
|
||||
handler = new OpenAiNativeHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
it('should initialize with empty API key', () => {
|
||||
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: ''
|
||||
});
|
||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
|
||||
});
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
it("should initialize with empty API key", () => {
|
||||
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||
apiModelId: "gpt-4o",
|
||||
openAiNativeApiKey: "",
|
||||
})
|
||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler)
|
||||
})
|
||||
})
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
describe("createMessage", () => {
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
|
||||
it('should handle missing content in response for o1 model', async () => {
|
||||
// Use o1 model which supports developer role
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: 'o1'
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: null } }],
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
}
|
||||
});
|
||||
it("should handle missing content in response for o1 model", async () => {
|
||||
// Use o1 model which supports developer role
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "o1",
|
||||
})
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages);
|
||||
const results = [];
|
||||
for await (const result of generator) {
|
||||
results.push(result);
|
||||
}
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: null } }],
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
},
|
||||
})
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: 'text', text: '' },
|
||||
{ type: 'usage', inputTokens: 0, outputTokens: 0 }
|
||||
]);
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
// Verify developer role is used for system prompt with o1 model
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1',
|
||||
messages: [
|
||||
{ role: 'developer', content: systemPrompt },
|
||||
{ role: 'user', content: 'Hello!' }
|
||||
]
|
||||
});
|
||||
});
|
||||
});
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "" },
|
||||
{ type: "usage", inputTokens: 0, outputTokens: 0 },
|
||||
])
|
||||
|
||||
describe('streaming models', () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: 'gpt-4o',
|
||||
});
|
||||
});
|
||||
// Verify developer role is used for system prompt with o1 model
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "o1",
|
||||
messages: [
|
||||
{ role: "developer", content: systemPrompt },
|
||||
{ role: "user", content: "Hello!" },
|
||||
],
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle streaming response', async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: { content: 'Hello' } }], usage: null },
|
||||
{ choices: [{ delta: { content: ' there' } }], usage: null },
|
||||
{ choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
];
|
||||
describe("streaming models", () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "gpt-4o",
|
||||
})
|
||||
})
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
}
|
||||
})()
|
||||
);
|
||||
it("should handle streaming response", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: null },
|
||||
{ choices: [{ delta: { content: " there" } }], usage: null },
|
||||
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages);
|
||||
const results = [];
|
||||
for await (const result of generator) {
|
||||
results.push(result);
|
||||
}
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
})(),
|
||||
)
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: ' there' },
|
||||
{ type: 'text', text: '!' },
|
||||
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
|
||||
]);
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'gpt-4o',
|
||||
temperature: 0,
|
||||
messages: [
|
||||
{ role: 'system', content: systemPrompt },
|
||||
{ role: 'user', content: 'Hello!' },
|
||||
],
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
});
|
||||
});
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "text", text: " there" },
|
||||
{ type: "text", text: "!" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
|
||||
it('should handle empty delta content', async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: {} }], usage: null },
|
||||
{ choices: [{ delta: { content: null } }], usage: null },
|
||||
{ choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
];
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "gpt-4o",
|
||||
temperature: 0,
|
||||
messages: [
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: "Hello!" },
|
||||
],
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
})
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
}
|
||||
})()
|
||||
);
|
||||
it("should handle empty delta content", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: {} }], usage: null },
|
||||
{ choices: [{ delta: { content: null } }], usage: null },
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages);
|
||||
const results = [];
|
||||
for await (const result of generator) {
|
||||
results.push(result);
|
||||
}
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
})(),
|
||||
)
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
|
||||
]);
|
||||
});
|
||||
});
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully with gpt-4o model', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
it('should complete prompt successfully with o1 model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully with gpt-4o model", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "gpt-4o",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
})
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
it("should complete prompt successfully with o1 model", async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: "o1",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
|
||||
it('should complete prompt successfully with o1-preview model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-preview',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "o1",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
})
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-preview',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
it("should complete prompt successfully with o1-preview model", async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: "o1-preview",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
|
||||
it('should complete prompt successfully with o1-mini model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-mini',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "o1-preview",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
})
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-mini',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
it("should complete prompt successfully with o1-mini model", async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: "o1-mini",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI Native completion error: API Error');
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "o1-mini",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"OpenAI Native completion error: API Error",
|
||||
)
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(4096);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: "" } }],
|
||||
})
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
const modelInfo = handlerWithoutModel.getModel();
|
||||
expect(modelInfo.id).toBe('gpt-4o'); // Default model
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId)
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(4096)
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||
})
|
||||
|
||||
it("should handle undefined model ID", () => {
|
||||
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
const modelInfo = handlerWithoutModel.getModel()
|
||||
expect(modelInfo.id).toBe("gpt-4o") // Default model
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@@ -1,224 +1,233 @@
||||
import { OpenAiHandler } from '../openai';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import { ApiStream } from '../../transform/stream';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { OpenAiHandler } from "../openai"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import { ApiStream } from "../../transform/stream"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response', refusal: null },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response", refusal: null },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
describe('OpenAiHandler', () => {
|
||||
let handler: OpenAiHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
openAiApiKey: 'test-api-key',
|
||||
openAiModelId: 'gpt-4',
|
||||
openAiBaseUrl: 'https://api.openai.com/v1'
|
||||
};
|
||||
handler = new OpenAiHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
describe("OpenAiHandler", () => {
|
||||
let handler: OpenAiHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.openAiModelId);
|
||||
});
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
openAiApiKey: "test-api-key",
|
||||
openAiModelId: "gpt-4",
|
||||
openAiBaseUrl: "https://api.openai.com/v1",
|
||||
}
|
||||
handler = new OpenAiHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
it('should use custom base URL if provided', () => {
|
||||
const customBaseUrl = 'https://custom.openai.com/v1';
|
||||
const handlerWithCustomUrl = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiBaseUrl: customBaseUrl
|
||||
});
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler);
|
||||
});
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.openAiModelId)
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello!'
|
||||
}]
|
||||
}
|
||||
];
|
||||
it("should use custom base URL if provided", () => {
const customBaseUrl = "https://custom.openai.com/v1"
const handlerWithCustomUrl = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
})
})

it('should handle non-streaming mode', async () => {
|
||||
const handler = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiStreamingEnabled: false
|
||||
});
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello!",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
it("should handle non-streaming mode", async () => {
|
||||
const handler = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiStreamingEnabled: false,
|
||||
})
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunk = chunks.find(chunk => chunk.type === 'text');
|
||||
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
|
||||
|
||||
expect(textChunk).toBeDefined();
|
||||
expect(textChunk?.text).toBe('Test response');
|
||||
expect(usageChunk).toBeDefined();
|
||||
expect(usageChunk?.inputTokens).toBe(10);
|
||||
expect(usageChunk?.outputTokens).toBe(5);
|
||||
});
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunk = chunks.find((chunk) => chunk.type === "text")
|
||||
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
});
|
||||
expect(textChunk).toBeDefined()
|
||||
expect(textChunk?.text).toBe("Test response")
|
||||
expect(usageChunk).toBeDefined()
|
||||
expect(usageChunk?.inputTokens).toBe(10)
|
||||
expect(usageChunk?.outputTokens).toBe(5)
|
||||
})
|
||||
|
||||
describe('error handling', () => {
|
||||
const testMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello'
|
||||
}]
|
||||
}
|
||||
];
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
})
|
||||
|
||||
const stream = handler.createMessage('system prompt', testMessages);
|
||||
describe("error handling", () => {
|
||||
const testMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
it('should handle rate limiting', async () => {
|
||||
const rateLimitError = new Error('Rate limit exceeded');
|
||||
rateLimitError.name = 'Error';
|
||||
(rateLimitError as any).status = 429;
|
||||
mockCreate.mockRejectedValueOnce(rateLimitError);
|
||||
const stream = handler.createMessage("system prompt", testMessages)
|
||||
|
||||
const stream = handler.createMessage('system prompt', testMessages);
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('Rate limit exceeded');
|
||||
});
|
||||
});
|
||||
it("should handle rate limiting", async () => {
|
||||
const rateLimitError = new Error("Rate limit exceeded")
|
||||
rateLimitError.name = "Error"
|
||||
;(rateLimitError as any).status = 429
|
||||
mockCreate.mockRejectedValueOnce(rateLimitError)
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openAiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
const stream = handler.createMessage("system prompt", testMessages)
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI completion error: API Error');
|
||||
});
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("Rate limit exceeded")
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockImplementationOnce(() => ({
|
||||
choices: [{ message: { content: '' } }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openAiModelId,
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info with sane defaults', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.openAiModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.contextWindow).toBe(128_000);
|
||||
expect(model.info.supportsImages).toBe(true);
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
const handlerWithoutModel = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBe('');
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockImplementationOnce(() => ({
|
||||
choices: [{ message: { content: "" } }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe("getModel", () => {
|
||||
it("should return model info with sane defaults", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.openAiModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.contextWindow).toBe(128_000)
|
||||
expect(model.info.supportsImages).toBe(true)
|
||||
})
|
||||
|
||||
it("should handle undefined model ID", () => {
|
||||
const handlerWithoutModel = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBe("")
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,283 +1,297 @@
import { OpenRouterHandler } from '../openrouter'
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
import OpenAI from 'openai'
import axios from 'axios'
import { Anthropic } from '@anthropic-ai/sdk'
import { OpenRouterHandler } from "../openrouter"
import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
import OpenAI from "openai"
import axios from "axios"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock dependencies
jest.mock('openai')
jest.mock('axios')
jest.mock('delay', () => jest.fn(() => Promise.resolve()))
jest.mock("openai")
jest.mock("axios")
jest.mock("delay", () => jest.fn(() => Promise.resolve()))

describe('OpenRouterHandler', () => {
|
||||
const mockOptions: ApiHandlerOptions = {
|
||||
openRouterApiKey: 'test-key',
|
||||
openRouterModelId: 'test-model',
|
||||
openRouterModelInfo: {
|
||||
name: 'Test Model',
|
||||
description: 'Test Description',
|
||||
maxTokens: 1000,
|
||||
contextWindow: 2000,
|
||||
supportsPromptCache: true,
|
||||
inputPrice: 0.01,
|
||||
outputPrice: 0.02
|
||||
} as ModelInfo
|
||||
}
|
||||
describe("OpenRouterHandler", () => {
|
||||
const mockOptions: ApiHandlerOptions = {
|
||||
openRouterApiKey: "test-key",
|
||||
openRouterModelId: "test-model",
|
||||
openRouterModelInfo: {
|
||||
name: "Test Model",
|
||||
description: "Test Description",
|
||||
maxTokens: 1000,
|
||||
contextWindow: 2000,
|
||||
supportsPromptCache: true,
|
||||
inputPrice: 0.01,
|
||||
outputPrice: 0.02,
|
||||
} as ModelInfo,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
test('constructor initializes with correct options', () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
expect(handler).toBeInstanceOf(OpenRouterHandler)
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: 'https://openrouter.ai/api/v1',
|
||||
apiKey: mockOptions.openRouterApiKey,
|
||||
defaultHeaders: {
|
||||
'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline',
|
||||
'X-Title': 'Roo-Cline',
|
||||
},
|
||||
})
|
||||
})
|
||||
test("constructor initializes with correct options", () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
expect(handler).toBeInstanceOf(OpenRouterHandler)
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: "https://openrouter.ai/api/v1",
|
||||
apiKey: mockOptions.openRouterApiKey,
|
||||
defaultHeaders: {
|
||||
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
|
||||
"X-Title": "Roo-Cline",
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test('getModel returns correct model info when options are provided', () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const result = handler.getModel()
|
||||
|
||||
expect(result).toEqual({
|
||||
id: mockOptions.openRouterModelId,
|
||||
info: mockOptions.openRouterModelInfo
|
||||
})
|
||||
})
|
||||
test("getModel returns correct model info when options are provided", () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const result = handler.getModel()
|
||||
|
||||
test('getModel returns default model info when options are not provided', () => {
|
||||
const handler = new OpenRouterHandler({})
|
||||
const result = handler.getModel()
|
||||
|
||||
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
|
||||
expect(result.info.supportsPromptCache).toBe(true)
|
||||
})
|
||||
expect(result).toEqual({
|
||||
id: mockOptions.openRouterModelId,
|
||||
info: mockOptions.openRouterModelInfo,
|
||||
})
|
||||
})
|
||||
|
||||
test('createMessage generates correct stream chunks', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
test("getModel returns default model info when options are not provided", () => {
|
||||
const handler = new OpenRouterHandler({})
|
||||
const result = handler.getModel()
|
||||
|
||||
// Mock OpenAI chat.completions.create
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta")
|
||||
expect(result.info.supportsPromptCache).toBe(true)
|
||||
})
|
||||
|
||||
// Mock axios.get for generation details
|
||||
;(axios.get as jest.Mock).mockResolvedValue({
|
||||
data: {
|
||||
data: {
|
||||
native_tokens_prompt: 10,
|
||||
native_tokens_completion: 20,
|
||||
total_cost: 0.001
|
||||
}
|
||||
}
|
||||
})
|
||||
test("createMessage generates correct stream chunks", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: "test-id",
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: "test response",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const systemPrompt = 'test system prompt'
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }]
|
||||
// Mock OpenAI chat.completions.create
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
|
||||
for await (const chunk of generator) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
// Mock axios.get for generation details
|
||||
;(axios.get as jest.Mock).mockResolvedValue({
|
||||
data: {
|
||||
data: {
|
||||
native_tokens_prompt: 10,
|
||||
native_tokens_completion: 20,
|
||||
total_cost: 0.001,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Verify stream chunks
|
||||
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'test response'
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'usage',
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalCost: 0.001,
|
||||
fullResponseText: 'test response'
|
||||
})
|
||||
const systemPrompt = "test system prompt"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
|
||||
|
||||
// Verify OpenAI client was called with correct parameters
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: mockOptions.openRouterModelId,
|
||||
temperature: 0,
|
||||
messages: expect.arrayContaining([
|
||||
{ role: 'system', content: systemPrompt },
|
||||
{ role: 'user', content: 'test message' }
|
||||
]),
|
||||
stream: true
|
||||
}))
|
||||
})
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
|
||||
test('createMessage with middle-out transform enabled', async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterUseMiddleOutTransform: true
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
for await (const chunk of generator) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
// Verify stream chunks
|
||||
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: "text",
|
||||
text: "test response",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalCost: 0.001,
|
||||
fullResponseText: "test response",
|
||||
})
|
||||
|
||||
await handler.createMessage('test', []).next()
|
||||
// Verify OpenAI client was called with correct parameters
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: mockOptions.openRouterModelId,
|
||||
temperature: 0,
|
||||
messages: expect.arrayContaining([
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: "test message" },
|
||||
]),
|
||||
stream: true,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
transforms: ['middle-out']
|
||||
}))
|
||||
})
|
||||
test("createMessage with middle-out transform enabled", async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterUseMiddleOutTransform: true,
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: "test-id",
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: "test response",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
test('createMessage with Claude model adds cache control', async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterModelId: 'anthropic/claude-3.5-sonnet'
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
await handler.createMessage("test", []).next()
|
||||
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'message 1' },
|
||||
{ role: 'assistant', content: 'response 1' },
|
||||
{ role: 'user', content: 'message 2' }
|
||||
]
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
transforms: ["middle-out"],
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
await handler.createMessage('test system', messages).next()
|
||||
test("createMessage with Claude model adds cache control", async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterModelId: "anthropic/claude-3.5-sonnet",
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: "test-id",
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: "test response",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'system',
|
||||
content: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
cache_control: { type: 'ephemeral' }
|
||||
})
|
||||
])
|
||||
})
|
||||
])
|
||||
}))
|
||||
})
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
|
||||
test('createMessage handles API errors', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
error: {
|
||||
message: 'API Error',
|
||||
code: 500
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: "user", content: "message 1" },
|
||||
{ role: "assistant", content: "response 1" },
|
||||
{ role: "user", content: "message 2" },
|
||||
]
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
await handler.createMessage("test system", messages).next()
|
||||
|
||||
const generator = handler.createMessage('test', [])
|
||||
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||
})
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: "system",
|
||||
content: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
cache_control: { type: "ephemeral" },
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
test('completePrompt returns correct response', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockResponse = {
|
||||
choices: [{
|
||||
message: {
|
||||
content: 'test completion'
|
||||
}
|
||||
}]
|
||||
}
|
||||
test("createMessage handles API errors", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
error: {
|
||||
message: "API Error",
|
||||
code: 500,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
const result = await handler.completePrompt('test prompt')
|
||||
const generator = handler.createMessage("test", [])
|
||||
await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
|
||||
})
|
||||
|
||||
expect(result).toBe('test completion')
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openRouterModelId,
|
||||
messages: [{ role: 'user', content: 'test prompt' }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
})
|
||||
})
|
||||
test("completePrompt returns correct response", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockResponse = {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: "test completion",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
test('completePrompt handles API errors', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockError = {
|
||||
error: {
|
||||
message: 'API Error',
|
||||
code: 500
|
||||
}
|
||||
}
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
const result = await handler.completePrompt("test prompt")
|
||||
|
||||
await expect(handler.completePrompt('test prompt'))
|
||||
.rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||
})
|
||||
expect(result).toBe("test completion")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openRouterModelId,
|
||||
messages: [{ role: "user", content: "test prompt" }],
|
||||
temperature: 0,
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
test('completePrompt handles unexpected errors', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
test("completePrompt handles API errors", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockError = {
|
||||
error: {
|
||||
message: "API Error",
|
||||
code: 500,
|
||||
},
|
||||
}
|
||||
|
||||
await expect(handler.completePrompt('test prompt'))
|
||||
.rejects.toThrow('OpenRouter completion error: Unexpected error')
|
||||
})
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error")
|
||||
})
|
||||
|
||||
test("completePrompt handles unexpected errors", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error"))
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
await expect(handler.completePrompt("test prompt")).rejects.toThrow(
|
||||
"OpenRouter completion error: Unexpected error",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,296 +1,295 @@
import { VertexHandler } from '../vertex';
import { Anthropic } from '@anthropic-ai/sdk';
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
import { VertexHandler } from "../vertex"
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"

// Mock Vertex SDK
jest.mock('@anthropic-ai/vertex-sdk', () => ({
|
||||
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
||||
messages: {
|
||||
create: jest.fn().mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
content: [
|
||||
{ type: 'text', text: 'Test response' }
|
||||
],
|
||||
role: 'assistant',
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
}));
|
||||
jest.mock("@anthropic-ai/vertex-sdk", () => ({
|
||||
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
||||
messages: {
|
||||
create: jest.fn().mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: "test-completion",
|
||||
content: [{ type: "text", text: "Test response" }],
|
||||
role: "assistant",
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
yield {
|
||||
type: "content_block_start",
|
||||
content_block: {
|
||||
type: "text",
|
||||
text: "Test response",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
})),
|
||||
}))
|
||||
|
||||
describe('VertexHandler', () => {
|
||||
let handler: VertexHandler;
|
||||
describe("VertexHandler", () => {
|
||||
let handler: VertexHandler
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new VertexHandler({
|
||||
apiModelId: 'claude-3-5-sonnet-v2@20241022',
|
||||
vertexProjectId: 'test-project',
|
||||
vertexRegion: 'us-central1'
|
||||
});
|
||||
});
|
||||
beforeEach(() => {
|
||||
handler = new VertexHandler({
|
||||
apiModelId: "claude-3-5-sonnet-v2@20241022",
|
||||
vertexProjectId: "test-project",
|
||||
vertexRegion: "us-central1",
|
||||
})
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(AnthropicVertex).toHaveBeenCalledWith({
|
||||
projectId: 'test-project',
|
||||
region: 'us-central1'
|
||||
});
|
||||
});
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided config", () => {
|
||||
expect(AnthropicVertex).toHaveBeenCalledWith({
|
||||
projectId: "test-project",
|
||||
region: "us-central1",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
describe("createMessage", () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
|
||||
it('should handle streaming responses correctly', async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
type: 'message_start',
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 0
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
type: 'content_block_start',
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
}
|
||||
},
|
||||
{
|
||||
type: 'content_block_delta',
|
||||
delta: {
|
||||
type: 'text_delta',
|
||||
text: ' world!'
|
||||
}
|
||||
},
|
||||
{
|
||||
type: 'message_delta',
|
||||
usage: {
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
];
|
||||
it("should handle streaming responses correctly", async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_delta",
|
||||
delta: {
|
||||
type: "text_delta",
|
||||
text: " world!",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "message_delta",
|
||||
usage: {
|
||||
output_tokens: 5,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
// Setup async iterator for mock stream
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
};
|
||||
// Setup async iterator for mock stream
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
expect(chunks.length).toBe(4);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'usage',
|
||||
inputTokens: 10,
|
||||
outputTokens: 0
|
||||
});
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
expect(chunks[2]).toEqual({
|
||||
type: 'text',
|
||||
text: ' world!'
|
||||
});
|
||||
expect(chunks[3]).toEqual({
|
||||
type: 'usage',
|
||||
inputTokens: 0,
|
||||
outputTokens: 5
|
||||
});
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'claude-3-5-sonnet-v2@20241022',
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
system: systemPrompt,
|
||||
messages: mockMessages,
|
||||
stream: true
|
||||
});
|
||||
});
|
||||
expect(chunks.length).toBe(4)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 0,
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
})
|
||||
expect(chunks[2]).toEqual({
|
||||
type: "text",
|
||||
text: " world!",
|
||||
})
|
||||
expect(chunks[3]).toEqual({
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 5,
|
||||
})
|
||||
|
||||
it('should handle multiple content blocks with line breaks', async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
type: 'content_block_start',
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'First line'
|
||||
}
|
||||
},
|
||||
{
|
||||
type: 'content_block_start',
|
||||
index: 1,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Second line'
|
||||
}
|
||||
}
|
||||
];
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "claude-3-5-sonnet-v2@20241022",
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
system: systemPrompt,
|
||||
messages: mockMessages,
|
||||
stream: true,
|
||||
})
|
||||
})
|
||||
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
};
|
||||
it("should handle multiple content blocks with line breaks", async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: "text",
|
||||
text: "First line",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_start",
|
||||
index: 1,
|
||||
content_block: {
|
||||
type: "text",
|
||||
text: "Second line",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
expect(chunks.length).toBe(3);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'First line'
|
||||
});
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: '\n'
|
||||
});
|
||||
expect(chunks[2]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Second line'
|
||||
});
|
||||
});
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Vertex API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
expect(chunks.length).toBe(3)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: "text",
|
||||
text: "First line",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: "text",
|
||||
text: "\n",
|
||||
})
|
||||
expect(chunks[2]).toEqual({
|
||||
type: "text",
|
||||
text: "Second line",
|
||||
})
|
||||
})
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('Vertex API error');
|
||||
});
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Vertex API error")
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(handler['client'].messages.create).toHaveBeenCalledWith({
|
||||
model: 'claude-3-5-sonnet-v2@20241022',
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Vertex API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow("Vertex API error")
|
||||
})
|
||||
})
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Vertex completion error: Vertex API error');
|
||||
});
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(handler["client"].messages.create).toHaveBeenCalledWith({
|
||||
model: "claude-3-5-sonnet-v2@20241022",
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle non-text content', async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'image' }]
|
||||
});
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Vertex API error")
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"Vertex completion error: Vertex API error",
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: '' }]
|
||||
});
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
it("should handle non-text content", async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: "image" }],
|
||||
})
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000);
|
||||
});
|
||||
it("should handle empty response", async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: "text", text: "" }],
|
||||
})
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
it('should return default model if invalid model specified', () => {
|
||||
const invalidHandler = new VertexHandler({
|
||||
apiModelId: 'invalid-model',
|
||||
vertexProjectId: 'test-project',
|
||||
vertexRegion: 'us-central1'
|
||||
});
|
||||
const modelInfo = invalidHandler.getModel();
|
||||
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model
|
||||
});
|
||||
});
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe("getModel", () => {
|
||||
it("should return correct model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000)
|
||||
})
|
||||
|
||||
it("should return default model if invalid model specified", () => {
|
||||
const invalidHandler = new VertexHandler({
|
||||
apiModelId: "invalid-model",
|
||||
vertexProjectId: "test-project",
|
||||
vertexRegion: "us-central1",
|
||||
})
|
||||
const modelInfo = invalidHandler.getModel()
|
||||
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,289 +1,295 @@
import * as vscode from 'vscode';
import { VsCodeLmHandler } from '../vscode-lm';
import { ApiHandlerOptions } from '../../../shared/api';
import { Anthropic } from '@anthropic-ai/sdk';
import * as vscode from "vscode"
import { VsCodeLmHandler } from "../vscode-lm"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock vscode namespace
jest.mock('vscode', () => {
|
||||
jest.mock("vscode", () => {
|
||||
class MockLanguageModelTextPart {
|
||||
type = 'text';
|
||||
type = "text"
|
||||
constructor(public value: string) {}
|
||||
}
|
||||
|
||||
class MockLanguageModelToolCallPart {
|
||||
type = 'tool_call';
|
||||
type = "tool_call"
|
||||
constructor(
|
||||
public callId: string,
|
||||
public name: string,
|
||||
public input: any
|
||||
public input: any,
|
||||
) {}
|
||||
}
|
||||
|
||||
return {
|
||||
workspace: {
|
||||
onDidChangeConfiguration: jest.fn((callback) => ({
|
||||
dispose: jest.fn()
|
||||
}))
|
||||
dispose: jest.fn(),
|
||||
})),
|
||||
},
|
||||
CancellationTokenSource: jest.fn(() => ({
|
||||
token: {
|
||||
isCancellationRequested: false,
|
||||
onCancellationRequested: jest.fn()
|
||||
onCancellationRequested: jest.fn(),
|
||||
},
|
||||
cancel: jest.fn(),
|
||||
dispose: jest.fn()
|
||||
dispose: jest.fn(),
|
||||
})),
|
||||
CancellationError: class CancellationError extends Error {
|
||||
constructor() {
|
||||
super('Operation cancelled');
|
||||
this.name = 'CancellationError';
|
||||
super("Operation cancelled")
|
||||
this.name = "CancellationError"
|
||||
}
|
||||
},
|
||||
LanguageModelChatMessage: {
|
||||
Assistant: jest.fn((content) => ({
|
||||
role: 'assistant',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
role: "assistant",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
User: jest.fn((content) => ({
|
||||
role: 'user',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
}))
|
||||
role: "user",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
},
|
||||
LanguageModelTextPart: MockLanguageModelTextPart,
|
||||
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
||||
lm: {
|
||||
selectChatModels: jest.fn()
|
||||
}
|
||||
};
|
||||
});
|
||||
selectChatModels: jest.fn(),
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
const mockLanguageModelChat = {
|
||||
id: 'test-model',
|
||||
name: 'Test Model',
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family',
|
||||
version: '1.0',
|
||||
id: "test-model",
|
||||
name: "Test Model",
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
version: "1.0",
|
||||
maxInputTokens: 4096,
|
||||
sendRequest: jest.fn(),
|
||||
countTokens: jest.fn()
|
||||
};
|
||||
countTokens: jest.fn(),
|
||||
}
|
||||
|
||||
describe('VsCodeLmHandler', () => {
|
||||
let handler: VsCodeLmHandler;
|
||||
describe("VsCodeLmHandler", () => {
|
||||
let handler: VsCodeLmHandler
|
||||
const defaultOptions: ApiHandlerOptions = {
|
||||
vsCodeLmModelSelector: {
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family'
|
||||
}
|
||||
};
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
},
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
handler = new VsCodeLmHandler(defaultOptions);
|
||||
});
|
||||
jest.clearAllMocks()
|
||||
handler = new VsCodeLmHandler(defaultOptions)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
handler.dispose();
|
||||
});
|
||||
handler.dispose()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeDefined();
|
||||
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled();
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeDefined()
|
||||
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle configuration changes', () => {
|
||||
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0];
|
||||
callback({ affectsConfiguration: () => true });
|
||||
it("should handle configuration changes", () => {
|
||||
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]
|
||||
callback({ affectsConfiguration: () => true })
|
||||
// Should reset client when config changes
|
||||
expect(handler['client']).toBeNull();
|
||||
});
|
||||
});
|
||||
expect(handler["client"]).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createClient', () => {
|
||||
it('should create client with selector', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
describe("createClient", () => {
|
||||
it("should create client with selector", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
const client = await handler['createClient']({
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family'
|
||||
});
|
||||
const client = await handler["createClient"]({
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
})
|
||||
|
||||
expect(client).toBeDefined();
|
||||
expect(client.id).toBe('test-model');
|
||||
expect(client).toBeDefined()
|
||||
expect(client.id).toBe("test-model")
|
||||
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family'
|
||||
});
|
||||
});
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
})
|
||||
})
|
||||
|
||||
it('should return default client when no models available', async () => {
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]);
|
||||
it("should return default client when no models available", async () => {
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([])
|
||||
|
||||
const client = await handler['createClient']({});
|
||||
|
||||
expect(client).toBeDefined();
|
||||
expect(client.id).toBe('default-lm');
|
||||
expect(client.vendor).toBe('vscode');
|
||||
});
|
||||
});
|
||||
const client = await handler["createClient"]({})
|
||||
|
||||
describe('createMessage', () => {
|
||||
expect(client).toBeDefined()
|
||||
expect(client.id).toBe("default-lm")
|
||||
expect(client.vendor).toBe("vscode")
|
||||
})
|
||||
})
|
||||
|
||||
describe("createMessage", () => {
|
||||
beforeEach(() => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
mockLanguageModelChat.countTokens.mockResolvedValue(10);
|
||||
});
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
mockLanguageModelChat.countTokens.mockResolvedValue(10)
|
||||
})
|
||||
|
||||
it('should stream text responses', async () => {
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user' as const,
|
||||
content: 'Hello'
|
||||
}];
|
||||
it("should stream text responses", async () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: "Hello",
|
||||
},
|
||||
]
|
||||
|
||||
const responseText = 'Hello! How can I help you?';
|
||||
const responseText = "Hello! How can I help you?"
|
||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelTextPart(responseText);
|
||||
return;
|
||||
yield new vscode.LanguageModelTextPart(responseText)
|
||||
return
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield responseText;
|
||||
return;
|
||||
})()
|
||||
});
|
||||
yield responseText
|
||||
return
|
||||
})(),
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks).toHaveLength(2); // Text chunk + usage chunk
|
||||
expect(chunks).toHaveLength(2) // Text chunk + usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: responseText
|
||||
});
|
||||
type: "text",
|
||||
text: responseText,
|
||||
})
|
||||
expect(chunks[1]).toMatchObject({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: expect.any(Number),
|
||||
outputTokens: expect.any(Number)
|
||||
});
|
||||
});
|
||||
outputTokens: expect.any(Number),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tool calls', async () => {
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user' as const,
|
||||
content: 'Calculate 2+2'
|
||||
}];
|
||||
it("should handle tool calls", async () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: "Calculate 2+2",
|
||||
},
|
||||
]
|
||||
|
||||
const toolCallData = {
|
||||
name: 'calculator',
|
||||
arguments: { operation: 'add', numbers: [2, 2] },
|
||||
callId: 'call-1'
|
||||
};
|
||||
name: "calculator",
|
||||
arguments: { operation: "add", numbers: [2, 2] },
|
||||
callId: "call-1",
|
||||
}
|
||||
|
||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelToolCallPart(
|
||||
toolCallData.callId,
|
||||
toolCallData.name,
|
||||
toolCallData.arguments
|
||||
);
|
||||
return;
|
||||
toolCallData.arguments,
|
||||
)
|
||||
return
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield JSON.stringify({ type: 'tool_call', ...toolCallData });
|
||||
return;
|
||||
})()
|
||||
});
|
||||
yield JSON.stringify({ type: "tool_call", ...toolCallData })
|
||||
return
|
||||
})(),
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk
|
||||
expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: JSON.stringify({ type: 'tool_call', ...toolCallData })
|
||||
});
|
||||
});
|
||||
type: "text",
|
||||
text: JSON.stringify({ type: "tool_call", ...toolCallData }),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle errors', async () => {
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user' as const,
|
||||
content: 'Hello'
|
||||
}];
|
||||
it("should handle errors", async () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: "Hello",
|
||||
},
|
||||
]
|
||||
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error'));
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
await expect(async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
})
|
||||
|
||||
describe("getModel", () => {
|
||||
it("should return model info when client exists", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info when client exists', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
|
||||
// Initialize client
|
||||
await handler['getClient']();
|
||||
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe('test-model');
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.contextWindow).toBe(4096);
|
||||
});
|
||||
await handler["getClient"]()
|
||||
|
||||
it('should return fallback model info when no client exists', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe('test-vendor/test-family');
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe("test-model")
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.contextWindow).toBe(4096)
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete single prompt', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
it("should return fallback model info when no client exists", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe("test-vendor/test-family")
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
const responseText = 'Completed text';
|
||||
describe("completePrompt", () => {
|
||||
it("should complete single prompt", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
const responseText = "Completed text"
|
||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelTextPart(responseText);
|
||||
return;
|
||||
yield new vscode.LanguageModelTextPart(responseText)
|
||||
return
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield responseText;
|
||||
return;
|
||||
})()
|
||||
});
|
||||
yield responseText
|
||||
return
|
||||
})(),
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe(responseText);
|
||||
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled();
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe(responseText)
|
||||
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle errors during completion', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
it("should handle errors during completion", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed'));
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects
|
||||
.toThrow('VSCode LM completion error: Completion failed');
|
||||
});
|
||||
});
|
||||
});
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"VSCode LM completion error: Completion failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -181,14 +181,14 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
|
||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
|
||||
const content = response.content[0]
|
||||
if (content.type === 'text') {
|
||||
if (content.type === "text") {
|
||||
return content.text
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Anthropic completion error: ${error.message}`)
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseStreamCommand,
|
||||
ConverseCommand,
|
||||
BedrockRuntimeClientConfig,
|
||||
} from "@aws-sdk/client-bedrock-runtime"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||
@@ -7,275 +12,276 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../
|
||||
|
||||
// Define types for stream events based on AWS SDK
|
||||
export interface StreamEvent {
|
||||
messageStart?: {
|
||||
role?: string;
|
||||
};
|
||||
messageStop?: {
|
||||
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence";
|
||||
additionalModelResponseFields?: Record<string, unknown>;
|
||||
};
|
||||
contentBlockStart?: {
|
||||
start?: {
|
||||
text?: string;
|
||||
};
|
||||
contentBlockIndex?: number;
|
||||
};
|
||||
contentBlockDelta?: {
|
||||
delta?: {
|
||||
text?: string;
|
||||
};
|
||||
contentBlockIndex?: number;
|
||||
};
|
||||
metadata?: {
|
||||
usage?: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalTokens?: number; // Made optional since we don't use it
|
||||
};
|
||||
metrics?: {
|
||||
latencyMs: number;
|
||||
};
|
||||
};
|
||||
	messageStart?: {
		role?: string
	}
	messageStop?: {
		stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
		additionalModelResponseFields?: Record<string, unknown>
	}
	contentBlockStart?: {
		start?: {
			text?: string
		}
		contentBlockIndex?: number
	}
	contentBlockDelta?: {
		delta?: {
			text?: string
		}
		contentBlockIndex?: number
	}
	metadata?: {
		usage?: {
			inputTokens: number
			outputTokens: number
			totalTokens?: number // Made optional since we don't use it
		}
		metrics?: {
			latencyMs: number
		}
	}
}
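Since every field on StreamEvent is optional, consumers must narrow each event before using it. A minimal illustrative sketch (not part of this diff; the helper name and simplified union are hypothetical) of how an event can be reduced to a text/usage value, mirroring the branching used in createMessage below:

```typescript
// Hypothetical helper, shown only to illustrate narrowing of the optional
// StreamEvent fields; it is not part of the handler.
type SimplifiedChunk =
	| { kind: "text"; text: string }
	| { kind: "usage"; inputTokens: number; outputTokens: number }
	| null

function simplifyStreamEvent(event: StreamEvent): SimplifiedChunk {
	if (event.metadata?.usage) {
		return {
			kind: "usage",
			inputTokens: event.metadata.usage.inputTokens || 0,
			outputTokens: event.metadata.usage.outputTokens || 0,
		}
	}
	const text = event.contentBlockStart?.start?.text ?? event.contentBlockDelta?.delta?.text
	if (text) {
		return { kind: "text", text }
	}
	// messageStart / messageStop carry no payload we need here
	return null
}
```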
|
||||
|
||||
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: BedrockRuntimeClient
|
||||
private options: ApiHandlerOptions
|
||||
private client: BedrockRuntimeClient
|
||||
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options
|
||||
|
||||
// Only include credentials if they actually exist
|
||||
const clientConfig: BedrockRuntimeClientConfig = {
|
||||
region: this.options.awsRegion || "us-east-1"
|
||||
}
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options
|
||||
|
||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||
// Create credentials object with all properties at once
|
||||
clientConfig.credentials = {
|
||||
accessKeyId: this.options.awsAccessKey,
|
||||
secretAccessKey: this.options.awsSecretKey,
|
||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {})
|
||||
}
|
||||
}
|
||||
// Only include credentials if they actually exist
|
||||
const clientConfig: BedrockRuntimeClientConfig = {
|
||||
region: this.options.awsRegion || "us-east-1",
|
||||
}
|
||||
|
||||
this.client = new BedrockRuntimeClient(clientConfig)
|
||||
}
|
||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||
// Create credentials object with all properties at once
|
||||
clientConfig.credentials = {
|
||||
accessKeyId: this.options.awsAccessKey,
|
||||
secretAccessKey: this.options.awsSecretKey,
|
||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
|
||||
}
|
||||
}
|
||||
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
this.client = new BedrockRuntimeClient(clientConfig)
|
||||
}
|
||||
|
||||
// Convert messages to Bedrock format
|
||||
const formattedMessages = convertToBedrockConverseMessages(messages)
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Construct the payload
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: formattedMessages,
|
||||
system: [{ text: systemPrompt }],
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
...(this.options.awsUsePromptCache ? {
|
||||
promptCache: {
|
||||
promptCacheId: this.options.awspromptCacheId || ""
|
||||
}
|
||||
} : {})
|
||||
}
|
||||
}
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
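The `us.` / `eu.` prefixes above appear to target Bedrock cross-region inference profile IDs. A standalone sketch of the same mapping (hypothetical helper name, not part of the diff), shown only to make the prefix logic easier to follow:

```typescript
// Illustrative only: the region-prefix mapping used above, as a pure function.
function toCrossRegionModelId(awsRegion: string | undefined, modelId: string): string {
	const regionPrefix = (awsRegion || "").slice(0, 3)
	switch (regionPrefix) {
		case "us-":
			return `us.${modelId}` // e.g. us.<model-id> inference profile
		case "eu-":
			return `eu.${modelId}`
		default:
			return modelId // fall back to the plain model ID
	}
}
```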
|
||||
|
||||
try {
|
||||
const command = new ConverseStreamCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
// Convert messages to Bedrock format
|
||||
const formattedMessages = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!response.stream) {
|
||||
throw new Error('No stream available in the response')
|
||||
}
|
||||
// Construct the payload
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: formattedMessages,
|
||||
system: [{ text: systemPrompt }],
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
...(this.options.awsUsePromptCache
|
||||
? {
|
||||
promptCache: {
|
||||
promptCacheId: this.options.awspromptCacheId || "",
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
},
|
||||
}
|
||||
|
||||
for await (const chunk of response.stream) {
|
||||
// Parse the chunk as JSON if it's a string (for tests)
|
||||
let streamEvent: StreamEvent
|
||||
try {
|
||||
streamEvent = typeof chunk === 'string' ?
|
||||
JSON.parse(chunk) :
|
||||
chunk as unknown as StreamEvent
|
||||
} catch (e) {
|
||||
console.error('Failed to parse stream event:', e)
|
||||
continue
|
||||
}
|
||||
try {
|
||||
const command = new ConverseStreamCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
|
||||
// Handle metadata events first
|
||||
if (streamEvent.metadata?.usage) {
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if (!response.stream) {
|
||||
throw new Error("No stream available in the response")
|
||||
}
|
||||
|
||||
// Handle message start
|
||||
if (streamEvent.messageStart) {
|
||||
continue
|
||||
}
|
||||
for await (const chunk of response.stream) {
|
||||
// Parse the chunk as JSON if it's a string (for tests)
|
||||
let streamEvent: StreamEvent
|
||||
try {
|
||||
streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
|
||||
} catch (e) {
|
||||
console.error("Failed to parse stream event:", e)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle content blocks
|
||||
if (streamEvent.contentBlockStart?.start?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockStart.start.text
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Handle metadata events first
|
||||
if (streamEvent.metadata?.usage) {
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle content deltas
|
||||
if (streamEvent.contentBlockDelta?.delta?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockDelta.delta.text
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Handle message start
|
||||
if (streamEvent.messageStart) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle message stop
|
||||
if (streamEvent.messageStop) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Handle content blocks
|
||||
if (streamEvent.contentBlockStart?.start?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockStart.start.text,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
} catch (error: unknown) {
|
||||
console.error('Bedrock Runtime API Error:', error)
|
||||
// Only access stack if error is an Error object
|
||||
if (error instanceof Error) {
|
||||
console.error('Error stack:', error.stack)
|
||||
yield {
|
||||
type: "text",
|
||||
text: `Error: ${error.message}`
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
}
|
||||
throw error
|
||||
} else {
|
||||
const unknownError = new Error("An unknown error occurred")
|
||||
yield {
|
||||
type: "text",
|
||||
text: unknownError.message
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
}
|
||||
throw unknownError
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle content deltas
|
||||
if (streamEvent.contentBlockDelta?.delta?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockDelta.delta.text,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
||||
const modelId = this.options.apiModelId
|
||||
if (modelId) {
|
||||
// For tests, allow any model ID
|
||||
if (process.env.NODE_ENV === 'test') {
|
||||
return {
|
||||
id: modelId,
|
||||
info: {
|
||||
maxTokens: 5000,
|
||||
contextWindow: 128_000,
|
||||
supportsPromptCache: false
|
||||
}
|
||||
}
|
||||
}
|
||||
// For production, validate against known models
|
||||
if (modelId in bedrockModels) {
|
||||
const id = modelId as BedrockModelId
|
||||
return { id, info: bedrockModels[id] }
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId]
|
||||
}
|
||||
}
|
||||
// Handle message stop
|
||||
if (streamEvent.messageStop) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
console.error("Bedrock Runtime API Error:", error)
|
||||
// Only access stack if error is an Error object
|
||||
if (error instanceof Error) {
|
||||
console.error("Error stack:", error.stack)
|
||||
yield {
|
||||
type: "text",
|
||||
text: `Error: ${error.message}`,
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
}
|
||||
throw error
|
||||
} else {
|
||||
const unknownError = new Error("An unknown error occurred")
|
||||
yield {
|
||||
type: "text",
|
||||
text: unknownError.message,
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
}
|
||||
throw unknownError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
||||
const modelId = this.options.apiModelId
|
||||
if (modelId) {
|
||||
// For tests, allow any model ID
|
||||
if (process.env.NODE_ENV === "test") {
|
||||
return {
|
||||
id: modelId,
|
||||
info: {
|
||||
maxTokens: 5000,
|
||||
contextWindow: 128_000,
|
||||
supportsPromptCache: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
// For production, validate against known models
|
||||
if (modelId in bedrockModels) {
|
||||
const id = modelId as BedrockModelId
|
||||
return { id, info: bedrockModels[id] }
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId],
|
||||
}
|
||||
}
|
||||
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([{
|
||||
role: "user",
|
||||
content: prompt
|
||||
}]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
}
|
||||
}
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
|
||||
if (response.output && response.output instanceof Uint8Array) {
|
||||
try {
|
||||
const outputStr = new TextDecoder().decode(response.output)
|
||||
const output = JSON.parse(outputStr)
|
||||
if (output.content) {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse Bedrock response:', parseError)
|
||||
}
|
||||
}
|
||||
return ''
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([
|
||||
{
|
||||
role: "user",
|
||||
content: prompt,
|
||||
},
|
||||
]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
|
||||
if (response.output && response.output instanceof Uint8Array) {
|
||||
try {
|
||||
const outputStr = new TextDecoder().decode(response.output)
|
||||
const output = JSON.parse(outputStr)
|
||||
if (output.content) {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error("Failed to parse Bedrock response:", parseError)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,24 +3,24 @@ import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
|
||||
import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
|
||||
|
||||
export class DeepSeekHandler extends OpenAiHandler {
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
if (!options.deepSeekApiKey) {
|
||||
throw new Error("DeepSeek API key is required. Please provide it in the settings.")
|
||||
}
|
||||
super({
|
||||
...options,
|
||||
openAiApiKey: options.deepSeekApiKey,
|
||||
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
|
||||
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
|
||||
includeMaxTokens: true
|
||||
})
|
||||
}
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
if (!options.deepSeekApiKey) {
|
||||
throw new Error("DeepSeek API key is required. Please provide it in the settings.")
|
||||
}
|
||||
super({
|
||||
...options,
|
||||
openAiApiKey: options.deepSeekApiKey,
|
||||
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
|
||||
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
|
||||
includeMaxTokens: true,
|
||||
})
|
||||
}
|
||||
|
||||
override getModel(): { id: string; info: ModelInfo } {
|
||||
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
|
||||
return {
|
||||
id: modelId,
|
||||
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
|
||||
}
|
||||
}
|
||||
override getModel(): { id: string; info: ModelInfo } {
|
||||
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
|
||||
return {
|
||||
id: modelId,
|
||||
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
|
||||
}
|
||||
}
|
||||
}
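A minimal usage sketch (not part of the diff) of how this handler is constructed with the options referenced above; the environment variable and the cast are illustrative assumptions, and the constructor throws if no API key is supplied:

```typescript
// Illustrative only: real callers pass a fully-populated ApiHandlerOptions.
const deepSeek = new DeepSeekHandler({
	deepSeekApiKey: process.env.DEEPSEEK_API_KEY, // throws above if undefined
	deepSeekModelId: deepSeekDefaultModelId,
} as ApiHandlerOptions)

console.log(deepSeek.getModel().id) // resolves to the configured or default DeepSeek model
```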
|
||||
|
||||
@@ -72,17 +72,17 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
maxTokens = 8_192
|
||||
}
|
||||
|
||||
const { data: completion, response } = await this.client.chat.completions.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: maxTokens,
|
||||
temperature: 0,
|
||||
messages: openAiMessages,
|
||||
stream: true,
|
||||
}).withResponse();
|
||||
const { data: completion, response } = await this.client.chat.completions
|
||||
.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: maxTokens,
|
||||
temperature: 0,
|
||||
messages: openAiMessages,
|
||||
stream: true,
|
||||
})
|
||||
.withResponse()
|
||||
|
||||
const completionRequestId = response.headers.get(
|
||||
'x-completion-request-id',
|
||||
);
|
||||
const completionRequestId = response.headers.get("x-completion-request-id")
|
||||
|
||||
for await (const chunk of completion) {
|
||||
const delta = chunk.choices[0]?.delta
|
||||
@@ -96,13 +96,16 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.get(`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.options.glamaApiKey}`,
|
||||
const response = await axios.get(
|
||||
`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`,
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.options.glamaApiKey}`,
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
const completionRequest = response.data;
|
||||
const completionRequest = response.data
|
||||
|
||||
if (completionRequest.tokenUsage) {
|
||||
yield {
|
||||
@@ -113,7 +116,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
outputTokens: completionRequest.tokenUsage.completionTokens,
|
||||
totalCost: parseFloat(completionRequest.totalCostUsd),
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching Glama completion details", error)
|
||||
}
|
||||
@@ -126,7 +129,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
if (modelId && modelInfo) {
|
||||
return { id: modelId, info: modelInfo }
|
||||
}
|
||||
|
||||
|
||||
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
|
||||
}
|
||||
|
||||
@@ -141,7 +144,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
if (this.getModel().id.startsWith("anthropic/")) {
|
||||
requestOptions.max_tokens = 8192
|
||||
}
|
||||
|
||||
|
||||
const response = await this.client.chat.completions.create(requestOptions)
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
|
||||
@@ -60,7 +60,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
|
||||
@@ -53,7 +53,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
|
||||
@@ -32,7 +32,10 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
|
||||
// o1 doesn't support streaming or non-1 temp, but does support a developer prompt
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: modelId,
|
||||
messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
|
||||
messages: [
|
||||
{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
|
||||
...convertToOpenAiMessages(messages),
|
||||
],
|
||||
})
|
||||
yield {
|
||||
type: "text",
|
||||
@@ -98,14 +101,14 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
|
||||
// o1 doesn't support non-1 temp
|
||||
requestOptions = {
|
||||
model: modelId,
|
||||
messages: [{ role: "user", content: prompt }]
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
}
|
||||
break
|
||||
default:
|
||||
requestOptions = {
|
||||
model: modelId,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0
|
||||
temperature: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options
|
||||
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
|
||||
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host;
|
||||
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host
|
||||
if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) {
|
||||
this.client = new AzureOpenAI({
|
||||
baseURL: this.options.openAiBaseUrl,
|
||||
@@ -39,7 +39,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
if (this.options.openAiStreamingEnabled ?? true) {
|
||||
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
|
||||
role: "system",
|
||||
content: systemPrompt
|
||||
content: systemPrompt,
|
||||
}
|
||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: modelId,
|
||||
@@ -74,14 +74,14 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
// o1, for instance, doesn't support streaming, non-1 temp, or system prompt
|
||||
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
|
||||
role: "user",
|
||||
content: systemPrompt
|
||||
content: systemPrompt,
|
||||
}
|
||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model: modelId,
|
||||
messages: [systemMessage, ...convertToOpenAiMessages(messages)],
|
||||
}
|
||||
const response = await this.client.chat.completions.create(requestOptions)
|
||||
|
||||
|
||||
yield {
|
||||
type: "text",
|
||||
text: response.choices[0]?.message.content || "",
|
||||
@@ -108,7 +108,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
}
|
||||
|
||||
|
||||
const response = await this.client.chat.completions.create(requestOptions)
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
|
||||
@@ -9,12 +9,12 @@ import delay from "delay"
|
||||
|
||||
// Add custom interface for OpenRouter params
|
||||
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
|
||||
transforms?: string[];
|
||||
transforms?: string[]
|
||||
}
|
||||
|
||||
// Add custom interface for OpenRouter usage chunk
|
||||
interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
|
||||
fullResponseText: string;
|
||||
fullResponseText: string
|
||||
}
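A small sketch (not part of the diff; the model ID is a placeholder) of what the widened params type allows — the OpenRouter-specific `transforms` field rides along with the standard OpenAI chat params, per the transforms docs linked further down:

```typescript
// Illustrative only: OpenRouterChatCompletionParams accepts the extra field.
const exampleParams: OpenRouterChatCompletionParams = {
	model: "some-provider/some-model", // placeholder model ID
	messages: [{ role: "user", content: "Hello" }],
	stream: true,
	// https://openrouter.ai/docs/transforms — "middle-out" compresses over-long prompts
	transforms: ["middle-out"],
}
```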
|
||||
|
||||
import { SingleCompletionHandler } from ".."
|
||||
@@ -35,7 +35,10 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
})
|
||||
}
|
||||
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator<ApiStreamChunk> {
|
||||
async *createMessage(
|
||||
systemPrompt: string,
|
||||
messages: Anthropic.Messages.MessageParam[],
|
||||
): AsyncGenerator<ApiStreamChunk> {
|
||||
// Convert Anthropic messages to OpenAI format
|
||||
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
|
||||
{ role: "system", content: systemPrompt },
|
||||
@@ -108,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
break
|
||||
}
|
||||
// https://openrouter.ai/docs/transforms
|
||||
let fullResponseText = "";
|
||||
let fullResponseText = ""
|
||||
const stream = await this.client.chat.completions.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: maxTokens,
|
||||
@@ -116,8 +119,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
messages: openAiMessages,
|
||||
stream: true,
|
||||
// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
|
||||
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] })
|
||||
} as OpenRouterChatCompletionParams);
|
||||
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
|
||||
} as OpenRouterChatCompletionParams)
|
||||
|
||||
let genId: string | undefined
|
||||
|
||||
@@ -135,11 +138,11 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
|
||||
const delta = chunk.choices[0]?.delta
|
||||
if (delta?.content) {
|
||||
fullResponseText += delta.content;
|
||||
fullResponseText += delta.content
|
||||
yield {
|
||||
type: "text",
|
||||
text: delta.content,
|
||||
} as ApiStreamChunk;
|
||||
} as ApiStreamChunk
|
||||
}
|
||||
// if (chunk.usage) {
|
||||
// yield {
|
||||
@@ -170,13 +173,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
inputTokens: generation?.native_tokens_prompt || 0,
|
||||
outputTokens: generation?.native_tokens_completion || 0,
|
||||
totalCost: generation?.total_cost || 0,
|
||||
fullResponseText
|
||||
} as OpenRouterApiStreamUsageChunk;
|
||||
fullResponseText,
|
||||
} as OpenRouterApiStreamUsageChunk
|
||||
} catch (error) {
|
||||
// ignore if fails
|
||||
console.error("Error fetching OpenRouter generation details:", error)
|
||||
}
|
||||
|
||||
}
|
||||
getModel(): { id: string; info: ModelInfo } {
|
||||
const modelId = this.options.openRouterModelId
|
||||
@@ -193,7 +195,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
|
||||
if ("error" in response) {
|
||||
|
||||
@@ -91,14 +91,14 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
|
||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
|
||||
const content = response.content[0]
|
||||
if (content.type === 'text') {
|
||||
if (content.type === "text") {
|
||||
return content.text
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Vertex completion error: ${error.message}`)
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk";
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiHandler, SingleCompletionHandler } from "../";
|
||||
import { calculateApiCost } from "../../utils/cost";
|
||||
import { ApiStream } from "../transform/stream";
|
||||
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format";
|
||||
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils";
|
||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api";
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import * as vscode from "vscode"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { calculateApiCost } from "../../utils/cost"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
|
||||
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
|
||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||
|
||||
/**
|
||||
* Handles interaction with VS Code's Language Model API for chat-based operations.
|
||||
* This handler implements the ApiHandler interface to provide VS Code LM specific functionality.
|
||||
*
|
||||
*
|
||||
* @implements {ApiHandler}
|
||||
*
|
||||
*
|
||||
* @remarks
|
||||
* The handler manages a VS Code language model chat client and provides methods to:
|
||||
* - Create and manage chat client instances
|
||||
* - Stream messages using VS Code's Language Model API
|
||||
* - Retrieve model information
|
||||
*
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const options = {
|
||||
* vsCodeLmModelSelector: { vendor: "copilot", family: "gpt-4" }
|
||||
* };
|
||||
* const handler = new VsCodeLmHandler(options);
|
||||
*
|
||||
*
|
||||
* // Stream a conversation
|
||||
* const systemPrompt = "You are a helpful assistant";
|
||||
* const messages = [{ role: "user", content: "Hello!" }];
|
||||
@@ -35,39 +35,36 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
|
||||
* ```
|
||||
*/
|
||||
export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
|
||||
private options: ApiHandlerOptions;
|
||||
private client: vscode.LanguageModelChat | null;
|
||||
private disposable: vscode.Disposable | null;
|
||||
private currentRequestCancellation: vscode.CancellationTokenSource | null;
|
||||
private options: ApiHandlerOptions
|
||||
private client: vscode.LanguageModelChat | null
|
||||
private disposable: vscode.Disposable | null
|
||||
private currentRequestCancellation: vscode.CancellationTokenSource | null
|
||||
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options;
|
||||
this.client = null;
|
||||
this.disposable = null;
|
||||
this.currentRequestCancellation = null;
|
||||
this.options = options
|
||||
this.client = null
|
||||
this.disposable = null
|
||||
this.currentRequestCancellation = null
|
||||
|
||||
try {
|
||||
// Listen for model changes and reset client
|
||||
this.disposable = vscode.workspace.onDidChangeConfiguration(event => {
|
||||
if (event.affectsConfiguration('lm')) {
|
||||
this.disposable = vscode.workspace.onDidChangeConfiguration((event) => {
|
||||
if (event.affectsConfiguration("lm")) {
|
||||
try {
|
||||
this.client = null;
|
||||
this.ensureCleanState();
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error during configuration change cleanup:', error);
|
||||
this.client = null
|
||||
this.ensureCleanState()
|
||||
} catch (error) {
|
||||
console.error("Error during configuration change cleanup:", error)
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
catch (error) {
|
||||
})
|
||||
} catch (error) {
|
||||
// Ensure cleanup if constructor fails
|
||||
this.dispose();
|
||||
this.dispose()
|
||||
|
||||
throw new Error(
|
||||
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
);
|
||||
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,46 +74,46 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
* @param selector - Selector criteria to filter language model chat instances
|
||||
* @returns Promise resolving to the first matching language model chat instance
|
||||
* @throws Error when no matching models are found with the given selector
|
||||
*
|
||||
*
|
||||
* @example
|
||||
* const selector = { vendor: "copilot", family: "gpt-4o" };
|
||||
* const chatClient = await createClient(selector);
|
||||
*/
|
||||
async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> {
|
||||
try {
|
||||
const models = await vscode.lm.selectChatModels(selector);
|
||||
const models = await vscode.lm.selectChatModels(selector)
|
||||
|
||||
// Use first available model or create a minimal model object
|
||||
if (models && Array.isArray(models) && models.length > 0) {
|
||||
return models[0];
|
||||
return models[0]
|
||||
}
|
||||
|
||||
// Create a minimal model if no models are available
|
||||
return {
|
||||
id: 'default-lm',
|
||||
name: 'Default Language Model',
|
||||
vendor: 'vscode',
|
||||
family: 'lm',
|
||||
version: '1.0',
|
||||
id: "default-lm",
|
||||
name: "Default Language Model",
|
||||
vendor: "vscode",
|
||||
family: "lm",
|
||||
version: "1.0",
|
||||
maxInputTokens: 8192,
|
||||
sendRequest: async (messages, options, token) => {
|
||||
// Provide a minimal implementation
|
||||
return {
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelTextPart(
|
||||
"Language model functionality is limited. Please check VS Code configuration."
|
||||
);
|
||||
"Language model functionality is limited. Please check VS Code configuration.",
|
||||
)
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield "Language model functionality is limited. Please check VS Code configuration.";
|
||||
})()
|
||||
};
|
||||
yield "Language model functionality is limited. Please check VS Code configuration."
|
||||
})(),
|
||||
}
|
||||
},
|
||||
countTokens: async () => 0
|
||||
};
|
||||
countTokens: async () => 0,
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`);
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error"
|
||||
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,242 +122,234 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
*
|
||||
* @param systemPrompt - The system prompt to initialize the conversation context
|
||||
* @param messages - An array of message parameters following the Anthropic message format
|
||||
*
|
||||
*
|
||||
* @yields {ApiStream} An async generator that yields either text chunks or tool calls from the model response
|
||||
*
|
||||
*
|
||||
* @throws {Error} When vsCodeLmModelSelector option is not provided
|
||||
* @throws {Error} When the response stream encounters an error
|
||||
*
|
||||
*
|
||||
* @remarks
|
||||
* This method handles the initialization of the VS Code LM client if not already created,
|
||||
* converts the messages to VS Code LM format, and streams the response chunks.
|
||||
* Tool calls handling is currently a work in progress.
|
||||
*/
|
||||
dispose(): void {
|
||||
|
||||
if (this.disposable) {
|
||||
|
||||
this.disposable.dispose();
|
||||
this.disposable.dispose()
|
||||
}
|
||||
|
||||
if (this.currentRequestCancellation) {
|
||||
|
||||
this.currentRequestCancellation.cancel();
|
||||
this.currentRequestCancellation.dispose();
|
||||
this.currentRequestCancellation.cancel()
|
||||
this.currentRequestCancellation.dispose()
|
||||
}
|
||||
}
|
||||
|
||||
private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
|
||||
// Check for required dependencies
|
||||
if (!this.client) {
|
||||
console.warn('Cline <Language Model API>: No client available for token counting');
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: No client available for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
if (!this.currentRequestCancellation) {
|
||||
console.warn('Cline <Language Model API>: No cancellation token available for token counting');
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: No cancellation token available for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
// Validate input
|
||||
if (!text) {
|
||||
console.debug('Cline <Language Model API>: Empty text provided for token counting');
|
||||
return 0;
|
||||
console.debug("Cline <Language Model API>: Empty text provided for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
try {
|
||||
// Handle different input types
|
||||
let tokenCount: number;
|
||||
let tokenCount: number
|
||||
|
||||
if (typeof text === 'string') {
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
|
||||
if (typeof text === "string") {
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
|
||||
} else if (text instanceof vscode.LanguageModelChatMessage) {
|
||||
// For chat messages, ensure we have content
|
||||
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
|
||||
console.debug('Cline <Language Model API>: Empty chat message content');
|
||||
return 0;
|
||||
console.debug("Cline <Language Model API>: Empty chat message content")
|
||||
return 0
|
||||
}
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
|
||||
} else {
|
||||
console.warn('Cline <Language Model API>: Invalid input type for token counting');
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: Invalid input type for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
// Validate the result
|
||||
if (typeof tokenCount !== 'number') {
|
||||
console.warn('Cline <Language Model API>: Non-numeric token count received:', tokenCount);
|
||||
return 0;
|
||||
if (typeof tokenCount !== "number") {
|
||||
console.warn("Cline <Language Model API>: Non-numeric token count received:", tokenCount)
|
||||
return 0
|
||||
}
|
||||
|
||||
if (tokenCount < 0) {
|
||||
console.warn('Cline <Language Model API>: Negative token count received:', tokenCount);
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: Negative token count received:", tokenCount)
|
||||
return 0
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
catch (error) {
|
||||
return tokenCount
|
||||
} catch (error) {
|
||||
// Handle specific error types
|
||||
if (error instanceof vscode.CancellationError) {
|
||||
console.debug('Cline <Language Model API>: Token counting cancelled by user');
|
||||
return 0;
|
||||
console.debug("Cline <Language Model API>: Token counting cancelled by user")
|
||||
return 0
|
||||
}
|
||||
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
console.warn('Cline <Language Model API>: Token counting failed:', errorMessage);
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error"
|
||||
console.warn("Cline <Language Model API>: Token counting failed:", errorMessage)
|
||||
|
||||
// Log additional error details if available
|
||||
if (error instanceof Error && error.stack) {
|
||||
console.debug('Token counting error stack:', error.stack);
|
||||
console.debug("Token counting error stack:", error.stack)
|
||||
}
|
||||
|
||||
return 0; // Fallback to prevent stream interruption
|
||||
return 0 // Fallback to prevent stream interruption
|
||||
}
|
||||
}
|
||||
|
||||
private async calculateTotalInputTokens(systemPrompt: string, vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
|
||||
private async calculateTotalInputTokens(
|
||||
systemPrompt: string,
|
||||
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
|
||||
): Promise<number> {
|
||||
const systemTokens: number = await this.countTokens(systemPrompt)
|
||||
|
||||
const systemTokens: number = await this.countTokens(systemPrompt);
|
||||
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.countTokens(msg)))
|
||||
|
||||
const messageTokens: number[] = await Promise.all(
|
||||
vsCodeLmMessages.map(msg => this.countTokens(msg))
|
||||
);
|
||||
|
||||
return systemTokens + messageTokens.reduce(
|
||||
(sum: number, tokens: number): number => sum + tokens, 0
|
||||
);
|
||||
return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
|
||||
}
|
||||
|
||||
private ensureCleanState(): void {
|
||||
|
||||
if (this.currentRequestCancellation) {
|
||||
|
||||
this.currentRequestCancellation.cancel();
|
||||
this.currentRequestCancellation.dispose();
|
||||
this.currentRequestCancellation = null;
|
||||
this.currentRequestCancellation.cancel()
|
||||
this.currentRequestCancellation.dispose()
|
||||
this.currentRequestCancellation = null
|
||||
}
|
||||
}
|
||||
|
||||
private async getClient(): Promise<vscode.LanguageModelChat> {
|
||||
if (!this.client) {
|
||||
console.debug('Cline <Language Model API>: Getting client with options:', {
|
||||
console.debug("Cline <Language Model API>: Getting client with options:", {
|
||||
vsCodeLmModelSelector: this.options.vsCodeLmModelSelector,
|
||||
hasOptions: !!this.options,
|
||||
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : []
|
||||
});
|
||||
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [],
|
||||
})
|
||||
|
||||
try {
|
||||
// Use default empty selector if none provided to get all available models
|
||||
const selector = this.options?.vsCodeLmModelSelector || {};
|
||||
console.debug('Cline <Language Model API>: Creating client with selector:', selector);
|
||||
this.client = await this.createClient(selector);
|
||||
const selector = this.options?.vsCodeLmModelSelector || {}
|
||||
console.debug("Cline <Language Model API>: Creating client with selector:", selector)
|
||||
this.client = await this.createClient(selector)
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Unknown error';
|
||||
console.error('Cline <Language Model API>: Client creation failed:', message);
|
||||
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`);
|
||||
const message = error instanceof Error ? error.message : "Unknown error"
|
||||
console.error("Cline <Language Model API>: Client creation failed:", message)
|
||||
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`)
|
||||
}
|
||||
}
|
||||
|
||||
return this.client;
|
||||
return this.client
|
||||
}
|
||||
|
||||
private cleanTerminalOutput(text: string): string {
|
||||
if (!text) {
|
||||
return '';
|
||||
return ""
|
||||
}
|
||||
|
||||
return text
|
||||
// Normalize line breaks
|
||||
.replace(/\r\n/g, '\n')
|
||||
.replace(/\r/g, '\n')
|
||||
return (
|
||||
text
|
||||
// Normalize line breaks
|
||||
.replace(/\r\n/g, "\n")
|
||||
.replace(/\r/g, "\n")
|
||||
|
||||
// Remove ANSI escape sequences
|
||||
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, '') // Full set of ANSI sequences
|
||||
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, '') // CSI sequences
|
||||
// Remove ANSI escape sequences
|
||||
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // Full set of ANSI sequences
|
||||
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, "") // CSI sequences
|
||||
|
||||
// Remove terminal title-setting sequences and other OSC sequences
|
||||
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, '')
|
||||
// Remove terminal title-setting sequences and other OSC sequences
|
||||
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, "")
|
||||
|
||||
// Remove control characters
|
||||
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, '')
|
||||
// Remove control characters
|
||||
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, "")
|
||||
|
||||
// Remove VS Code escape sequences
|
||||
.replace(/\x1B[PD].*?\x1B\\/g, '') // DCS sequences
|
||||
.replace(/\x1B_.*?\x1B\\/g, '') // APC sequences
|
||||
.replace(/\x1B\^.*?\x1B\\/g, '') // PM sequences
|
||||
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, '') // Cursor movement and clear screen
|
||||
// Remove VS Code escape sequences
|
||||
.replace(/\x1B[PD].*?\x1B\\/g, "") // DCS sequences
|
||||
.replace(/\x1B_.*?\x1B\\/g, "") // APC sequences
|
||||
.replace(/\x1B\^.*?\x1B\\/g, "") // PM sequences
|
||||
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, "") // Cursor movement and clear screen
|
||||
|
||||
// Remove Windows paths and auxiliary information
|
||||
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/mg, '')
|
||||
.replace(/^;?Cwd=.*$/mg, '')
|
||||
// Remove Windows paths and auxiliary information
|
||||
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/gm, "")
|
||||
.replace(/^;?Cwd=.*$/gm, "")
|
||||
|
||||
// Clean up escaped sequences
|
||||
.replace(/\\x[0-9a-fA-F]{2}/g, '')
|
||||
.replace(/\\u[0-9a-fA-F]{4}/g, '')
|
||||
// Clean up escaped sequences
|
||||
.replace(/\\x[0-9a-fA-F]{2}/g, "")
|
||||
.replace(/\\u[0-9a-fA-F]{4}/g, "")
|
||||
|
||||
// Final cleanup
|
||||
.replace(/\n{3,}/g, '\n\n') // Collapse multiple blank lines
|
||||
.trim();
|
||||
// Final cleanup
|
||||
.replace(/\n{3,}/g, "\n\n") // Убираем множественные пустые строки
|
||||
.trim()
|
||||
)
|
||||
}
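To make the intent of the regex chain concrete, a small self-contained sketch (not part of the diff; the sample input and expected result are illustrative assumptions) applying the same ANSI stripping and blank-line collapsing to a captured terminal string:

```typescript
// Standalone illustration of the cleanup above.
const rawCapture = "\x1B[32mnpm test\x1B[0m\n\n\n\nAll tests passed\n"

const cleaned = rawCapture
	.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // strip ANSI escape sequences
	.replace(/\n{3,}/g, "\n\n") // collapse runs of blank lines
	.trim()

// cleaned === "npm test\n\nAll tests passed"
```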
|
||||
|
||||
private cleanMessageContent(content: any): any {
|
||||
if (!content) {
|
||||
return content;
|
||||
return content
|
||||
}
|
||||
|
||||
if (typeof content === 'string') {
|
||||
return this.cleanTerminalOutput(content);
|
||||
if (typeof content === "string") {
|
||||
return this.cleanTerminalOutput(content)
|
||||
}
|
||||
|
||||
if (Array.isArray(content)) {
|
||||
return content.map(item => this.cleanMessageContent(item));
|
||||
return content.map((item) => this.cleanMessageContent(item))
|
||||
}
|
||||
|
||||
if (typeof content === 'object') {
|
||||
const cleaned: any = {};
|
||||
if (typeof content === "object") {
|
||||
const cleaned: any = {}
|
||||
for (const [key, value] of Object.entries(content)) {
|
||||
cleaned[key] = this.cleanMessageContent(value);
|
||||
cleaned[key] = this.cleanMessageContent(value)
|
||||
}
|
||||
return cleaned;
|
||||
return cleaned
|
||||
}
|
||||
|
||||
return content;
|
||||
return content
|
||||
}
|
||||
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||
|
||||
// Ensure clean state before starting a new request
|
||||
this.ensureCleanState();
|
||||
const client: vscode.LanguageModelChat = await this.getClient();
|
||||
this.ensureCleanState()
|
||||
const client: vscode.LanguageModelChat = await this.getClient()
|
||||
|
||||
// Clean system prompt and messages
|
||||
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt);
|
||||
const cleanedMessages = messages.map(msg => ({
|
||||
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt)
|
||||
const cleanedMessages = messages.map((msg) => ({
|
||||
...msg,
|
||||
content: this.cleanMessageContent(msg.content)
|
||||
}));
|
||||
content: this.cleanMessageContent(msg.content),
|
||||
}))
|
||||
|
||||
// Convert Anthropic messages to VS Code LM messages
|
||||
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [
|
||||
vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt),
|
||||
...convertToVsCodeLmMessages(cleanedMessages),
|
||||
];
|
||||
]
|
||||
|
||||
// Initialize cancellation token for the request
|
||||
this.currentRequestCancellation = new vscode.CancellationTokenSource();
|
||||
this.currentRequestCancellation = new vscode.CancellationTokenSource()
|
||||
|
||||
// Calculate input tokens before starting the stream
|
||||
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages);
|
||||
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
|
||||
|
||||
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
|
||||
let accumulatedText: string = '';
|
||||
let accumulatedText: string = ""
|
||||
|
||||
try {
|
||||
|
||||
// Create the response stream with minimal required options
|
||||
const requestOptions: vscode.LanguageModelChatRequestOptions = {
|
||||
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`
|
||||
};
|
||||
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`,
|
||||
}
|
||||
|
||||
// Note: Tool support is currently provided by the VSCode Language Model API directly
|
||||
// Extensions can register tools using vscode.lm.registerTool()
|
||||
@@ -368,40 +357,40 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
const response: vscode.LanguageModelChatResponse = await client.sendRequest(
|
||||
vsCodeLmMessages,
|
||||
requestOptions,
|
||||
this.currentRequestCancellation.token
|
||||
);
|
||||
this.currentRequestCancellation.token,
|
||||
)
|
||||
|
||||
// Consume the stream and handle both text and tool call chunks
|
||||
for await (const chunk of response.stream) {
|
||||
if (chunk instanceof vscode.LanguageModelTextPart) {
|
||||
// Validate text part value
|
||||
if (typeof chunk.value !== 'string') {
|
||||
console.warn('Cline <Language Model API>: Invalid text part value received:', chunk.value);
|
||||
continue;
|
||||
if (typeof chunk.value !== "string") {
|
||||
console.warn("Cline <Language Model API>: Invalid text part value received:", chunk.value)
|
||||
continue
|
||||
}
|
||||
|
||||
accumulatedText += chunk.value;
|
||||
accumulatedText += chunk.value
|
||||
yield {
|
||||
type: "text",
|
||||
text: chunk.value,
|
||||
};
|
||||
}
|
||||
} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
|
||||
try {
|
||||
// Validate tool call parameters
|
||||
if (!chunk.name || typeof chunk.name !== 'string') {
|
||||
console.warn('Cline <Language Model API>: Invalid tool name received:', chunk.name);
|
||||
continue;
|
||||
if (!chunk.name || typeof chunk.name !== "string") {
|
||||
console.warn("Cline <Language Model API>: Invalid tool name received:", chunk.name)
|
||||
continue
|
||||
}
|
||||
|
||||
if (!chunk.callId || typeof chunk.callId !== 'string') {
|
||||
console.warn('Cline <Language Model API>: Invalid tool callId received:', chunk.callId);
|
||||
continue;
|
||||
if (!chunk.callId || typeof chunk.callId !== "string") {
|
||||
console.warn("Cline <Language Model API>: Invalid tool callId received:", chunk.callId)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure input is a valid object
|
||||
if (!chunk.input || typeof chunk.input !== 'object') {
|
||||
console.warn('Cline <Language Model API>: Invalid tool input received:', chunk.input);
|
||||
continue;
|
||||
if (!chunk.input || typeof chunk.input !== "object") {
|
||||
console.warn("Cline <Language Model API>: Invalid tool input received:", chunk.input)
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert tool calls to text format with proper error handling
|
||||
@@ -409,82 +398,75 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
type: "tool_call",
|
||||
name: chunk.name,
|
||||
arguments: chunk.input,
|
||||
callId: chunk.callId
|
||||
};
|
||||
callId: chunk.callId,
|
||||
}
|
||||
|
||||
const toolCallText = JSON.stringify(toolCall);
|
||||
accumulatedText += toolCallText;
|
||||
const toolCallText = JSON.stringify(toolCall)
|
||||
accumulatedText += toolCallText
|
||||
|
||||
// Log tool call for debugging
|
||||
console.debug('Cline <Language Model API>: Processing tool call:', {
|
||||
console.debug("Cline <Language Model API>: Processing tool call:", {
|
||||
name: chunk.name,
|
||||
callId: chunk.callId,
|
||||
inputSize: JSON.stringify(chunk.input).length
|
||||
});
|
||||
inputSize: JSON.stringify(chunk.input).length,
|
||||
})
|
||||
|
||||
yield {
|
||||
type: "text",
|
||||
text: toolCallText,
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Cline <Language Model API>: Failed to process tool call:', error);
|
||||
console.error("Cline <Language Model API>: Failed to process tool call:", error)
|
||||
// Continue processing other chunks even if one fails
|
||||
continue;
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
console.warn('Cline <Language Model API>: Unknown chunk type received:', chunk);
|
||||
console.warn("Cline <Language Model API>: Unknown chunk type received:", chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// Count tokens in the accumulated text after stream completion
|
||||
const totalOutputTokens: number = await this.countTokens(accumulatedText);
|
||||
const totalOutputTokens: number = await this.countTokens(accumulatedText)
|
||||
|
||||
// Report final usage after stream completion
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: totalInputTokens,
|
||||
outputTokens: totalOutputTokens,
|
||||
totalCost: calculateApiCost(
|
||||
this.getModel().info,
|
||||
totalInputTokens,
|
||||
totalOutputTokens
|
||||
)
|
||||
};
|
||||
}
|
||||
catch (error: unknown) {
|
||||
|
||||
this.ensureCleanState();
|
||||
totalCost: calculateApiCost(this.getModel().info, totalInputTokens, totalOutputTokens),
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
this.ensureCleanState()
|
||||
|
||||
if (error instanceof vscode.CancellationError) {
|
||||
|
||||
throw new Error("Cline <Language Model API>: Request cancelled by user");
|
||||
throw new Error("Cline <Language Model API>: Request cancelled by user")
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
console.error('Cline <Language Model API>: Stream error details:', {
console.error("Cline <Language Model API>: Stream error details:", {
message: error.message,
stack: error.stack,
name: error.name
});
name: error.name,
})
// Return original error if it's already an Error instance
throw error;
} else if (typeof error === 'object' && error !== null) {
throw error
} else if (typeof error === "object" && error !== null) {
// Handle error-like objects
const errorDetails = JSON.stringify(error, null, 2);
console.error('Cline <Language Model API>: Stream error object:', errorDetails);
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`);
const errorDetails = JSON.stringify(error, null, 2)
console.error("Cline <Language Model API>: Stream error object:", errorDetails)
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`)
} else {
// Fallback for unknown error types
const errorMessage = String(error);
console.error('Cline <Language Model API>: Unknown stream error:', errorMessage);
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`);
const errorMessage = String(error)
console.error("Cline <Language Model API>: Unknown stream error:", errorMessage)
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`)
}
}
}
// Return model information based on the current client state
|
||||
getModel(): { id: string; info: ModelInfo; } {
|
||||
getModel(): { id: string; info: ModelInfo } {
|
||||
if (this.client) {
|
||||
// Validate client properties
|
||||
const requiredProps = {
|
||||
@@ -492,68 +474,69 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
vendor: this.client.vendor,
|
||||
family: this.client.family,
|
||||
version: this.client.version,
|
||||
maxInputTokens: this.client.maxInputTokens
|
||||
};
|
||||
maxInputTokens: this.client.maxInputTokens,
|
||||
}
|
||||
|
||||
// Log any missing properties for debugging
|
||||
for (const [prop, value] of Object.entries(requiredProps)) {
|
||||
if (!value && value !== 0) {
|
||||
console.warn(`Cline <Language Model API>: Client missing ${prop} property`);
|
||||
console.warn(`Cline <Language Model API>: Client missing ${prop} property`)
|
||||
}
|
||||
}
|
||||
|
||||
// Construct model ID using available information
|
||||
const modelParts = [
|
||||
this.client.vendor,
|
||||
this.client.family,
|
||||
this.client.version
|
||||
].filter(Boolean);
|
||||
const modelParts = [this.client.vendor, this.client.family, this.client.version].filter(Boolean)
|
||||
|
||||
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR);
|
||||
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR)
|
||||
|
||||
// Build model info with conservative defaults for missing values
|
||||
const modelInfo: ModelInfo = {
|
||||
maxTokens: -1, // Unlimited tokens by default
|
||||
contextWindow: typeof this.client.maxInputTokens === 'number'
|
||||
? Math.max(0, this.client.maxInputTokens)
|
||||
: openAiModelInfoSaneDefaults.contextWindow,
|
||||
contextWindow:
|
||||
typeof this.client.maxInputTokens === "number"
|
||||
? Math.max(0, this.client.maxInputTokens)
|
||||
: openAiModelInfoSaneDefaults.contextWindow,
|
||||
supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
|
||||
supportsPromptCache: true,
|
||||
inputPrice: 0,
|
||||
outputPrice: 0,
|
||||
description: `VSCode Language Model: ${modelId}`
|
||||
};
|
||||
description: `VSCode Language Model: ${modelId}`,
|
||||
}
|
||||
|
||||
return { id: modelId, info: modelInfo };
|
||||
return { id: modelId, info: modelInfo }
|
||||
}
|
||||
|
||||
// Fallback when no client is available
|
||||
const fallbackId = this.options.vsCodeLmModelSelector
|
||||
? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector)
|
||||
: "vscode-lm";
|
||||
: "vscode-lm"
|
||||
|
||||
console.debug('Cline <Language Model API>: No client available, using fallback model info');
|
||||
console.debug("Cline <Language Model API>: No client available, using fallback model info")
|
||||
|
||||
return {
|
||||
id: fallbackId,
|
||||
info: {
|
||||
...openAiModelInfoSaneDefaults,
|
||||
description: `VSCode Language Model (Fallback): ${fallbackId}`
|
||||
}
|
||||
};
|
||||
description: `VSCode Language Model (Fallback): ${fallbackId}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const client = await this.getClient();
|
||||
const response = await client.sendRequest([vscode.LanguageModelChatMessage.User(prompt)], {}, new vscode.CancellationTokenSource().token);
|
||||
let result = "";
|
||||
const client = await this.getClient()
|
||||
const response = await client.sendRequest(
|
||||
[vscode.LanguageModelChatMessage.User(prompt)],
|
||||
{},
|
||||
new vscode.CancellationTokenSource().token,
|
||||
)
|
||||
let result = ""
|
||||
for await (const chunk of response.stream) {
|
||||
if (chunk instanceof vscode.LanguageModelTextPart) {
|
||||
result += chunk.value;
|
||||
result += chunk.value
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return result
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`VSCode LM completion error: ${error.message}`)
|
||||
|
||||
@@ -1,252 +1,250 @@
|
||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format'
|
||||
import { Anthropic } from '@anthropic-ai/sdk'
|
||||
import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime'
|
||||
import { StreamEvent } from '../../providers/bedrock'
|
||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../bedrock-converse-format"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { StreamEvent } from "../../providers/bedrock"
|
||||
|
||||
describe('bedrock-converse-format', () => {
|
||||
describe('convertToBedrockConverseMessages', () => {
|
||||
test('converts simple text messages correctly', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
{ role: 'assistant', content: 'Hi there' }
|
||||
]
|
||||
describe("bedrock-converse-format", () => {
|
||||
describe("convertToBedrockConverseMessages", () => {
|
||||
test("converts simple text messages correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: "user", content: "Hello" },
|
||||
{ role: "assistant", content: "Hi there" },
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ text: 'Hello' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ text: 'Hi there' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: "user",
|
||||
content: [{ text: "Hello" }],
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: [{ text: "Hi there" }],
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
test('converts messages with images correctly', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Look at this image:'
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
type: 'base64',
|
||||
data: 'SGVsbG8=', // "Hello" in base64
|
||||
media_type: 'image/jpeg' as const
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
test("converts messages with images correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Look at this image:",
|
||||
},
|
||||
{
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
data: "SGVsbG8=", // "Hello" in base64
|
||||
media_type: "image/jpeg" as const,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
return
|
||||
}
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
expect(result[0].role).toBe('user')
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
expect(result[0].content[0]).toEqual({ text: 'Look at this image:' })
|
||||
|
||||
const imageBlock = result[0].content[1] as ContentBlock
|
||||
if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) {
|
||||
expect(imageBlock.image.format).toBe('jpeg')
|
||||
expect(imageBlock.image.source).toBeDefined()
|
||||
expect(imageBlock.image.source.bytes).toBeDefined()
|
||||
} else {
|
||||
fail('Expected image block not found')
|
||||
}
|
||||
})
|
||||
expect(result[0].role).toBe("user")
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
expect(result[0].content[0]).toEqual({ text: "Look at this image:" })
|
||||
|
||||
test('converts tool use messages correctly', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'test-id',
|
||||
name: 'read_file',
|
||||
input: {
|
||||
path: 'test.txt'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
const imageBlock = result[0].content[1] as ContentBlock
|
||||
if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) {
|
||||
expect(imageBlock.image.format).toBe("jpeg")
|
||||
expect(imageBlock.image.source).toBeDefined()
|
||||
expect(imageBlock.image.source.bytes).toBeDefined()
|
||||
} else {
|
||||
fail("Expected image block not found")
|
||||
}
|
||||
})
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
test("converts tool use messages correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "test-id",
|
||||
name: "read_file",
|
||||
input: {
|
||||
path: "test.txt",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
return
|
||||
}
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
expect(result[0].role).toBe('assistant')
|
||||
const toolBlock = result[0].content[0] as ContentBlock
|
||||
if ('toolUse' in toolBlock && toolBlock.toolUse) {
|
||||
expect(toolBlock.toolUse).toEqual({
|
||||
toolUseId: 'test-id',
|
||||
name: 'read_file',
|
||||
input: '<read_file>\n<path>\ntest.txt\n</path>\n</read_file>'
|
||||
})
|
||||
} else {
|
||||
fail('Expected tool use block not found')
|
||||
}
|
||||
})
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
test('converts tool result messages correctly', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'test-id',
|
||||
content: [{ type: 'text', text: 'File contents here' }]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
expect(result[0].role).toBe("assistant")
|
||||
const toolBlock = result[0].content[0] as ContentBlock
|
||||
if ("toolUse" in toolBlock && toolBlock.toolUse) {
|
||||
expect(toolBlock.toolUse).toEqual({
|
||||
toolUseId: "test-id",
|
||||
name: "read_file",
|
||||
input: "<read_file>\n<path>\ntest.txt\n</path>\n</read_file>",
|
||||
})
|
||||
} else {
|
||||
fail("Expected tool use block not found")
|
||||
}
|
||||
})
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
test("converts tool result messages correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: "test-id",
|
||||
content: [{ type: "text", text: "File contents here" }],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
return
|
||||
}
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
expect(result[0].role).toBe('assistant')
|
||||
const resultBlock = result[0].content[0] as ContentBlock
|
||||
if ('toolResult' in resultBlock && resultBlock.toolResult) {
|
||||
const expectedContent: ToolResultContentBlock[] = [
|
||||
{ text: 'File contents here' }
|
||||
]
|
||||
expect(resultBlock.toolResult).toEqual({
|
||||
toolUseId: 'test-id',
|
||||
content: expectedContent,
|
||||
status: 'success'
|
||||
})
|
||||
} else {
|
||||
fail('Expected tool result block not found')
|
||||
}
|
||||
})
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
test('handles text content correctly', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Hello world'
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
expect(result[0].role).toBe("assistant")
|
||||
const resultBlock = result[0].content[0] as ContentBlock
|
||||
if ("toolResult" in resultBlock && resultBlock.toolResult) {
|
||||
const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }]
|
||||
expect(resultBlock.toolResult).toEqual({
|
||||
toolUseId: "test-id",
|
||||
content: expectedContent,
|
||||
status: "success",
|
||||
})
|
||||
} else {
|
||||
fail("Expected tool result block not found")
|
||||
}
|
||||
})
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
test("handles text content correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Hello world",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
return
|
||||
}
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
expect(result[0].role).toBe('user')
|
||||
expect(result[0].content).toHaveLength(1)
|
||||
const textBlock = result[0].content[0] as ContentBlock
|
||||
expect(textBlock).toEqual({ text: 'Hello world' })
|
||||
})
|
||||
})
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
describe('convertToAnthropicMessage', () => {
|
||||
test('converts metadata events correctly', () => {
|
||||
const event: StreamEvent = {
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20
|
||||
}
|
||||
}
|
||||
}
|
||||
expect(result[0].role).toBe("user")
|
||||
expect(result[0].content).toHaveLength(1)
|
||||
const textBlock = result[0].content[0] as ContentBlock
|
||||
expect(textBlock).toEqual({ text: "Hello world" })
|
||||
})
|
||||
})
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
describe("convertToAnthropicMessage", () => {
|
||||
test("converts metadata events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expect(result).toEqual({
|
||||
id: '',
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
model: 'test-model',
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20
|
||||
}
|
||||
})
|
||||
})
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
test('converts content block start events correctly', () => {
|
||||
const event: StreamEvent = {
|
||||
contentBlockStart: {
|
||||
start: {
|
||||
text: 'Hello'
|
||||
}
|
||||
}
|
||||
}
|
||||
expect(result).toEqual({
|
||||
id: "",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: "test-model",
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
test("converts content block start events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
contentBlockStart: {
|
||||
start: {
|
||||
text: "Hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Hello' }],
|
||||
model: 'test-model'
|
||||
})
|
||||
})
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
test('converts content block delta events correctly', () => {
|
||||
const event: StreamEvent = {
|
||||
contentBlockDelta: {
|
||||
delta: {
|
||||
text: ' world'
|
||||
}
|
||||
}
|
||||
}
|
||||
expect(result).toEqual({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello" }],
|
||||
model: "test-model",
|
||||
})
|
||||
})
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
test("converts content block delta events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
contentBlockDelta: {
|
||||
delta: {
|
||||
text: " world",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: ' world' }],
|
||||
model: 'test-model'
|
||||
})
|
||||
})
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
test('converts message stop events correctly', () => {
|
||||
const event: StreamEvent = {
|
||||
messageStop: {
|
||||
stopReason: 'end_turn' as const
|
||||
}
|
||||
}
|
||||
expect(result).toEqual({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: " world" }],
|
||||
model: "test-model",
|
||||
})
|
||||
})
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
test("converts message stop events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
messageStop: {
|
||||
stopReason: "end_turn" as const,
|
||||
},
|
||||
}
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
stop_reason: 'end_turn',
|
||||
stop_sequence: null,
|
||||
model: 'test-model'
|
||||
})
|
||||
})
|
||||
})
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
expect(result).toEqual({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
stop_reason: "end_turn",
|
||||
stop_sequence: null,
|
||||
model: "test-model",
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,257 +1,275 @@
|
||||
import { convertToOpenAiMessages, convertToAnthropicMessage } from '../openai-format';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import OpenAI from 'openai';
|
||||
import { convertToOpenAiMessages, convertToAnthropicMessage } from "../openai-format"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import OpenAI from "openai"
|
||||
|
||||
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, 'choices'> & {
|
||||
choices: Array<Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
|
||||
message: OpenAI.Chat.Completions.ChatCompletion.Choice['message'];
|
||||
finish_reason: string;
|
||||
index: number;
|
||||
}>;
|
||||
};
|
||||
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, "choices"> & {
|
||||
choices: Array<
|
||||
Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
|
||||
message: OpenAI.Chat.Completions.ChatCompletion.Choice["message"]
|
||||
finish_reason: string
|
||||
index: number
|
||||
}
|
||||
>
|
||||
}
|
||||
|
||||
describe('OpenAI Format Transformations', () => {
|
||||
describe('convertToOpenAiMessages', () => {
|
||||
it('should convert simple text messages', () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
describe("OpenAI Format Transformations", () => {
|
||||
describe("convertToOpenAiMessages", () => {
|
||||
it("should convert simple text messages", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(2);
|
||||
expect(openAiMessages[0]).toEqual({
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
});
|
||||
expect(openAiMessages[1]).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
});
|
||||
});
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(2)
|
||||
expect(openAiMessages[0]).toEqual({
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
})
|
||||
expect(openAiMessages[1]).toEqual({
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle messages with image content', () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'What is in this image?'
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'image/jpeg',
|
||||
data: 'base64data'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
];
|
||||
it("should handle messages with image content", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "What is in this image?",
|
||||
},
|
||||
{
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: "image/jpeg",
|
||||
data: "base64data",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(1);
|
||||
expect(openAiMessages[0].role).toBe('user');
|
||||
|
||||
const content = openAiMessages[0].content as Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
image_url?: { url: string };
|
||||
}>;
|
||||
|
||||
expect(Array.isArray(content)).toBe(true);
|
||||
expect(content).toHaveLength(2);
|
||||
expect(content[0]).toEqual({ type: 'text', text: 'What is in this image?' });
|
||||
expect(content[1]).toEqual({
|
||||
type: 'image_url',
|
||||
image_url: { url: 'data:image/jpeg;base64,base64data' }
|
||||
});
|
||||
});
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(1)
|
||||
expect(openAiMessages[0].role).toBe("user")
|
||||
|
||||
it('should handle assistant messages with tool use', () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Let me check the weather.'
|
||||
},
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'weather-123',
|
||||
name: 'get_weather',
|
||||
input: { city: 'London' }
|
||||
}
|
||||
]
|
||||
}
|
||||
];
|
||||
const content = openAiMessages[0].content as Array<{
|
||||
type: string
|
||||
text?: string
|
||||
image_url?: { url: string }
|
||||
}>
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(1);
|
||||
|
||||
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam;
|
||||
expect(assistantMessage.role).toBe('assistant');
|
||||
expect(assistantMessage.content).toBe('Let me check the weather.');
|
||||
expect(assistantMessage.tool_calls).toHaveLength(1);
|
||||
expect(assistantMessage.tool_calls![0]).toEqual({
|
||||
id: 'weather-123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'get_weather',
|
||||
arguments: JSON.stringify({ city: 'London' })
|
||||
}
|
||||
});
|
||||
});
|
||||
expect(Array.isArray(content)).toBe(true)
|
||||
expect(content).toHaveLength(2)
|
||||
expect(content[0]).toEqual({ type: "text", text: "What is in this image?" })
|
||||
expect(content[1]).toEqual({
|
||||
type: "image_url",
|
||||
image_url: { url: "data:image/jpeg;base64,base64data" },
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle user messages with tool results', () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'weather-123',
|
||||
content: 'Current temperature in London: 20°C'
|
||||
}
|
||||
]
|
||||
}
|
||||
];
|
||||
it("should handle assistant messages with tool use", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Let me check the weather.",
|
||||
},
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "weather-123",
|
||||
name: "get_weather",
|
||||
input: { city: "London" },
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(1);
|
||||
|
||||
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam;
|
||||
expect(toolMessage.role).toBe('tool');
|
||||
expect(toolMessage.tool_call_id).toBe('weather-123');
|
||||
expect(toolMessage.content).toBe('Current temperature in London: 20°C');
|
||||
});
|
||||
});
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(1)
|
||||
|
||||
describe('convertToAnthropicMessage', () => {
|
||||
it('should convert simple completion', () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: 'completion-123',
|
||||
model: 'gpt-4',
|
||||
choices: [{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'Hello there!',
|
||||
refusal: null
|
||||
},
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
},
|
||||
created: 123456789,
|
||||
object: 'chat.completion'
|
||||
};
|
||||
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam
|
||||
expect(assistantMessage.role).toBe("assistant")
|
||||
expect(assistantMessage.content).toBe("Let me check the weather.")
|
||||
expect(assistantMessage.tool_calls).toHaveLength(1)
|
||||
expect(assistantMessage.tool_calls![0]).toEqual({
|
||||
id: "weather-123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "get_weather",
|
||||
arguments: JSON.stringify({ city: "London" }),
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
||||
expect(anthropicMessage.id).toBe('completion-123');
|
||||
expect(anthropicMessage.role).toBe('assistant');
|
||||
expect(anthropicMessage.content).toHaveLength(1);
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello there!'
|
||||
});
|
||||
expect(anthropicMessage.stop_reason).toBe('end_turn');
|
||||
expect(anthropicMessage.usage).toEqual({
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
});
|
||||
});
|
||||
it("should handle user messages with tool results", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: "weather-123",
|
||||
content: "Current temperature in London: 20°C",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle tool calls in completion', () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: 'completion-123',
|
||||
model: 'gpt-4',
|
||||
choices: [{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'Let me check the weather.',
|
||||
tool_calls: [{
|
||||
id: 'weather-123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'get_weather',
|
||||
arguments: '{"city":"London"}'
|
||||
}
|
||||
}],
|
||||
refusal: null
|
||||
},
|
||||
finish_reason: 'tool_calls',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 15,
|
||||
completion_tokens: 8,
|
||||
total_tokens: 23
|
||||
},
|
||||
created: 123456789,
|
||||
object: 'chat.completion'
|
||||
};
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(1)
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
||||
expect(anthropicMessage.content).toHaveLength(2);
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Let me check the weather.'
|
||||
});
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: 'tool_use',
|
||||
id: 'weather-123',
|
||||
name: 'get_weather',
|
||||
input: { city: 'London' }
|
||||
});
|
||||
expect(anthropicMessage.stop_reason).toBe('tool_use');
|
||||
});
|
||||
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam
|
||||
expect(toolMessage.role).toBe("tool")
|
||||
expect(toolMessage.tool_call_id).toBe("weather-123")
|
||||
expect(toolMessage.content).toBe("Current temperature in London: 20°C")
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle invalid tool call arguments', () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: 'completion-123',
|
||||
model: 'gpt-4',
|
||||
choices: [{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'Testing invalid arguments',
|
||||
tool_calls: [{
|
||||
id: 'test-123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'test_function',
|
||||
arguments: 'invalid json'
|
||||
}
|
||||
}],
|
||||
refusal: null
|
||||
},
|
||||
finish_reason: 'tool_calls',
|
||||
index: 0
|
||||
}],
|
||||
created: 123456789,
|
||||
object: 'chat.completion'
|
||||
};
|
||||
describe("convertToAnthropicMessage", () => {
|
||||
it("should convert simple completion", () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: "completion-123",
|
||||
model: "gpt-4",
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Hello there!",
|
||||
refusal: null,
|
||||
},
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
created: 123456789,
|
||||
object: "chat.completion",
|
||||
}
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
||||
expect(anthropicMessage.content).toHaveLength(2);
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: 'tool_use',
|
||||
id: 'test-123',
|
||||
name: 'test_function',
|
||||
input: {} // Should default to empty object for invalid JSON
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
const anthropicMessage = convertToAnthropicMessage(
|
||||
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||
)
|
||||
expect(anthropicMessage.id).toBe("completion-123")
|
||||
expect(anthropicMessage.role).toBe("assistant")
|
||||
expect(anthropicMessage.content).toHaveLength(1)
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: "text",
|
||||
text: "Hello there!",
|
||||
})
|
||||
expect(anthropicMessage.stop_reason).toBe("end_turn")
|
||||
expect(anthropicMessage.usage).toEqual({
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle tool calls in completion", () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: "completion-123",
|
||||
model: "gpt-4",
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Let me check the weather.",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "weather-123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "get_weather",
|
||||
arguments: '{"city":"London"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
refusal: null,
|
||||
},
|
||||
finish_reason: "tool_calls",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 15,
|
||||
completion_tokens: 8,
|
||||
total_tokens: 23,
|
||||
},
|
||||
created: 123456789,
|
||||
object: "chat.completion",
|
||||
}
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(
|
||||
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||
)
|
||||
expect(anthropicMessage.content).toHaveLength(2)
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: "text",
|
||||
text: "Let me check the weather.",
|
||||
})
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: "tool_use",
|
||||
id: "weather-123",
|
||||
name: "get_weather",
|
||||
input: { city: "London" },
|
||||
})
|
||||
expect(anthropicMessage.stop_reason).toBe("tool_use")
|
||||
})
|
||||
|
||||
it("should handle invalid tool call arguments", () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: "completion-123",
|
||||
model: "gpt-4",
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Testing invalid arguments",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "test-123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "test_function",
|
||||
arguments: "invalid json",
|
||||
},
|
||||
},
|
||||
],
|
||||
refusal: null,
|
||||
},
|
||||
finish_reason: "tool_calls",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
created: 123456789,
|
||||
object: "chat.completion",
|
||||
}
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(
|
||||
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||
)
|
||||
expect(anthropicMessage.content).toHaveLength(2)
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: "tool_use",
|
||||
id: "test-123",
|
||||
name: "test_function",
|
||||
input: {}, // Should default to empty object for invalid JSON
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,114 +1,114 @@
|
||||
import { ApiStreamChunk } from '../stream';
|
||||
import { ApiStreamChunk } from "../stream"
|
||||
|
||||
describe('API Stream Types', () => {
|
||||
describe('ApiStreamChunk', () => {
|
||||
it('should correctly handle text chunks', () => {
|
||||
const textChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: 'Hello world'
|
||||
};
|
||||
describe("API Stream Types", () => {
|
||||
describe("ApiStreamChunk", () => {
|
||||
it("should correctly handle text chunks", () => {
|
||||
const textChunk: ApiStreamChunk = {
|
||||
type: "text",
|
||||
text: "Hello world",
|
||||
}
|
||||
|
||||
expect(textChunk.type).toBe('text');
|
||||
expect(textChunk.text).toBe('Hello world');
|
||||
});
|
||||
expect(textChunk.type).toBe("text")
|
||||
expect(textChunk.text).toBe("Hello world")
|
||||
})
|
||||
|
||||
it('should correctly handle usage chunks with cache information', () => {
|
||||
const usageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
inputTokens: 100,
|
||||
outputTokens: 50,
|
||||
cacheWriteTokens: 20,
|
||||
cacheReadTokens: 10
|
||||
};
|
||||
it("should correctly handle usage chunks with cache information", () => {
|
||||
const usageChunk: ApiStreamChunk = {
|
||||
type: "usage",
|
||||
inputTokens: 100,
|
||||
outputTokens: 50,
|
||||
cacheWriteTokens: 20,
|
||||
cacheReadTokens: 10,
|
||||
}
|
||||
|
||||
expect(usageChunk.type).toBe('usage');
|
||||
expect(usageChunk.inputTokens).toBe(100);
|
||||
expect(usageChunk.outputTokens).toBe(50);
|
||||
expect(usageChunk.cacheWriteTokens).toBe(20);
|
||||
expect(usageChunk.cacheReadTokens).toBe(10);
|
||||
});
|
||||
expect(usageChunk.type).toBe("usage")
|
||||
expect(usageChunk.inputTokens).toBe(100)
|
||||
expect(usageChunk.outputTokens).toBe(50)
|
||||
expect(usageChunk.cacheWriteTokens).toBe(20)
|
||||
expect(usageChunk.cacheReadTokens).toBe(10)
|
||||
})
|
||||
|
||||
it('should handle usage chunks without cache tokens', () => {
|
||||
const usageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
inputTokens: 100,
|
||||
outputTokens: 50
|
||||
};
|
||||
it("should handle usage chunks without cache tokens", () => {
|
||||
const usageChunk: ApiStreamChunk = {
|
||||
type: "usage",
|
||||
inputTokens: 100,
|
||||
outputTokens: 50,
|
||||
}
|
||||
|
||||
expect(usageChunk.type).toBe('usage');
|
||||
expect(usageChunk.inputTokens).toBe(100);
|
||||
expect(usageChunk.outputTokens).toBe(50);
|
||||
expect(usageChunk.cacheWriteTokens).toBeUndefined();
|
||||
expect(usageChunk.cacheReadTokens).toBeUndefined();
|
||||
});
|
||||
expect(usageChunk.type).toBe("usage")
|
||||
expect(usageChunk.inputTokens).toBe(100)
|
||||
expect(usageChunk.outputTokens).toBe(50)
|
||||
expect(usageChunk.cacheWriteTokens).toBeUndefined()
|
||||
expect(usageChunk.cacheReadTokens).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle text chunks with empty strings', () => {
|
||||
const emptyTextChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: ''
|
||||
};
|
||||
it("should handle text chunks with empty strings", () => {
|
||||
const emptyTextChunk: ApiStreamChunk = {
|
||||
type: "text",
|
||||
text: "",
|
||||
}
|
||||
|
||||
expect(emptyTextChunk.type).toBe('text');
|
||||
expect(emptyTextChunk.text).toBe('');
|
||||
});
|
||||
expect(emptyTextChunk.type).toBe("text")
|
||||
expect(emptyTextChunk.text).toBe("")
|
||||
})
|
||||
|
||||
it('should handle usage chunks with zero tokens', () => {
|
||||
const zeroUsageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
};
|
||||
it("should handle usage chunks with zero tokens", () => {
|
||||
const zeroUsageChunk: ApiStreamChunk = {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
}
|
||||
|
||||
expect(zeroUsageChunk.type).toBe('usage');
|
||||
expect(zeroUsageChunk.inputTokens).toBe(0);
|
||||
expect(zeroUsageChunk.outputTokens).toBe(0);
|
||||
});
|
||||
expect(zeroUsageChunk.type).toBe("usage")
|
||||
expect(zeroUsageChunk.inputTokens).toBe(0)
|
||||
expect(zeroUsageChunk.outputTokens).toBe(0)
|
||||
})
|
||||
|
||||
it('should handle usage chunks with large token counts', () => {
|
||||
const largeUsageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
inputTokens: 1000000,
|
||||
outputTokens: 500000,
|
||||
cacheWriteTokens: 200000,
|
||||
cacheReadTokens: 100000
|
||||
};
|
||||
it("should handle usage chunks with large token counts", () => {
|
||||
const largeUsageChunk: ApiStreamChunk = {
|
||||
type: "usage",
|
||||
inputTokens: 1000000,
|
||||
outputTokens: 500000,
|
||||
cacheWriteTokens: 200000,
|
||||
cacheReadTokens: 100000,
|
||||
}
|
||||
|
||||
expect(largeUsageChunk.type).toBe('usage');
|
||||
expect(largeUsageChunk.inputTokens).toBe(1000000);
|
||||
expect(largeUsageChunk.outputTokens).toBe(500000);
|
||||
expect(largeUsageChunk.cacheWriteTokens).toBe(200000);
|
||||
expect(largeUsageChunk.cacheReadTokens).toBe(100000);
|
||||
});
|
||||
expect(largeUsageChunk.type).toBe("usage")
|
||||
expect(largeUsageChunk.inputTokens).toBe(1000000)
|
||||
expect(largeUsageChunk.outputTokens).toBe(500000)
|
||||
expect(largeUsageChunk.cacheWriteTokens).toBe(200000)
|
||||
expect(largeUsageChunk.cacheReadTokens).toBe(100000)
|
||||
})
|
||||
|
||||
it('should handle text chunks with special characters', () => {
|
||||
const specialCharsChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: '!@#$%^&*()_+-=[]{}|;:,.<>?`~'
|
||||
};
|
||||
it("should handle text chunks with special characters", () => {
|
||||
const specialCharsChunk: ApiStreamChunk = {
|
||||
type: "text",
|
||||
text: "!@#$%^&*()_+-=[]{}|;:,.<>?`~",
|
||||
}
|
||||
|
||||
expect(specialCharsChunk.type).toBe('text');
|
||||
expect(specialCharsChunk.text).toBe('!@#$%^&*()_+-=[]{}|;:,.<>?`~');
|
||||
});
|
||||
expect(specialCharsChunk.type).toBe("text")
|
||||
expect(specialCharsChunk.text).toBe("!@#$%^&*()_+-=[]{}|;:,.<>?`~")
|
||||
})
|
||||
|
||||
it('should handle text chunks with unicode characters', () => {
|
||||
const unicodeChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: '你好世界👋🌍'
|
||||
};
|
||||
it("should handle text chunks with unicode characters", () => {
|
||||
const unicodeChunk: ApiStreamChunk = {
|
||||
type: "text",
|
||||
text: "你好世界👋🌍",
|
||||
}
|
||||
|
||||
expect(unicodeChunk.type).toBe('text');
|
||||
expect(unicodeChunk.text).toBe('你好世界👋🌍');
|
||||
});
|
||||
expect(unicodeChunk.type).toBe("text")
|
||||
expect(unicodeChunk.text).toBe("你好世界👋🌍")
|
||||
})
|
||||
|
||||
it('should handle text chunks with multiline content', () => {
|
||||
const multilineChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: 'Line 1\nLine 2\nLine 3'
|
||||
};
|
||||
it("should handle text chunks with multiline content", () => {
|
||||
const multilineChunk: ApiStreamChunk = {
|
||||
type: "text",
|
||||
text: "Line 1\nLine 2\nLine 3",
|
||||
}
|
||||
|
||||
expect(multilineChunk.type).toBe('text');
|
||||
expect(multilineChunk.text).toBe('Line 1\nLine 2\nLine 3');
|
||||
expect(multilineChunk.text.split('\n')).toHaveLength(3);
|
||||
});
|
||||
});
|
||||
});
|
||||
expect(multilineChunk.type).toBe("text")
|
||||
expect(multilineChunk.text).toBe("Line 1\nLine 2\nLine 3")
|
||||
expect(multilineChunk.text.split("\n")).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,66 +1,66 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk";
|
||||
import * as vscode from 'vscode';
|
||||
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from '../vscode-lm-format';
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import * as vscode from "vscode"
|
||||
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from "../vscode-lm-format"
|
||||
|
||||
// Mock crypto
|
||||
const mockCrypto = {
|
||||
randomUUID: () => 'test-uuid'
|
||||
};
|
||||
global.crypto = mockCrypto as any;
|
||||
randomUUID: () => "test-uuid",
|
||||
}
|
||||
global.crypto = mockCrypto as any
|
||||
|
||||
// Define types for our mocked classes
|
||||
interface MockLanguageModelTextPart {
|
||||
type: 'text';
|
||||
value: string;
|
||||
type: "text"
|
||||
value: string
|
||||
}
|
||||
|
||||
interface MockLanguageModelToolCallPart {
|
||||
type: 'tool_call';
|
||||
callId: string;
|
||||
name: string;
|
||||
input: any;
|
||||
type: "tool_call"
|
||||
callId: string
|
||||
name: string
|
||||
input: any
|
||||
}
|
||||
|
||||
interface MockLanguageModelToolResultPart {
|
||||
type: 'tool_result';
|
||||
toolUseId: string;
|
||||
parts: MockLanguageModelTextPart[];
|
||||
type: "tool_result"
|
||||
toolUseId: string
|
||||
parts: MockLanguageModelTextPart[]
|
||||
}
|
||||
|
||||
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart;
|
||||
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart
|
||||
|
||||
interface MockLanguageModelChatMessage {
|
||||
role: string;
|
||||
name?: string;
|
||||
content: MockMessageContent[];
|
||||
role: string
|
||||
name?: string
|
||||
content: MockMessageContent[]
|
||||
}
|
||||
|
||||
// Mock vscode namespace
|
||||
jest.mock('vscode', () => {
|
||||
jest.mock("vscode", () => {
|
||||
const LanguageModelChatMessageRole = {
|
||||
Assistant: 'assistant',
|
||||
User: 'user'
|
||||
};
|
||||
Assistant: "assistant",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
class MockLanguageModelTextPart {
|
||||
type = 'text';
|
||||
type = "text"
|
||||
constructor(public value: string) {}
|
||||
}
|
||||
|
||||
class MockLanguageModelToolCallPart {
|
||||
type = 'tool_call';
|
||||
type = "tool_call"
|
||||
constructor(
|
||||
public callId: string,
|
||||
public name: string,
|
||||
public input: any
|
||||
public input: any,
|
||||
) {}
|
||||
}
|
||||
|
||||
class MockLanguageModelToolResultPart {
|
||||
type = 'tool_result';
|
||||
type = "tool_result"
|
||||
constructor(
|
||||
public toolUseId: string,
|
||||
public parts: MockLanguageModelTextPart[]
|
||||
public parts: MockLanguageModelTextPart[],
|
||||
) {}
|
||||
}
|
||||
|
||||
@@ -68,179 +68,189 @@ jest.mock('vscode', () => {
|
||||
LanguageModelChatMessage: {
|
||||
Assistant: jest.fn((content) => ({
|
||||
role: LanguageModelChatMessageRole.Assistant,
|
||||
name: 'assistant',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
name: "assistant",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
User: jest.fn((content) => ({
|
||||
role: LanguageModelChatMessageRole.User,
|
||||
name: 'user',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
}))
|
||||
name: "user",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
},
|
||||
LanguageModelChatMessageRole,
|
||||
LanguageModelTextPart: MockLanguageModelTextPart,
|
||||
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
||||
LanguageModelToolResultPart: MockLanguageModelToolResultPart
|
||||
};
|
||||
});
|
||||
LanguageModelToolResultPart: MockLanguageModelToolResultPart,
|
||||
}
|
||||
})
|
||||
|
||||
describe('vscode-lm-format', () => {
|
||||
describe('convertToVsCodeLmMessages', () => {
|
||||
it('should convert simple string messages', () => {
|
||||
describe("vscode-lm-format", () => {
|
||||
describe("convertToVsCodeLmMessages", () => {
|
||||
it("should convert simple string messages", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
{ role: 'assistant', content: 'Hi there' }
|
||||
];
|
||||
{ role: "user", content: "Hello" },
|
||||
{ role: "assistant", content: "Hi there" },
|
||||
]
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].role).toBe('user');
|
||||
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe('Hello');
|
||||
expect(result[1].role).toBe('assistant');
|
||||
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe('Hi there');
|
||||
});
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
it('should handle complex user messages with tool results', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Here is the result:' },
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'tool-1',
|
||||
content: 'Tool output'
|
||||
}
|
||||
]
|
||||
}];
|
||||
expect(result).toHaveLength(2)
|
||||
expect(result[0].role).toBe("user")
|
||||
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe("Hello")
|
||||
expect(result[1].role).toBe("assistant")
|
||||
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe("Hi there")
|
||||
})
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].role).toBe('user');
|
||||
expect(result[0].content).toHaveLength(2);
|
||||
const [toolResult, textContent] = result[0].content as [MockLanguageModelToolResultPart, MockLanguageModelTextPart];
|
||||
expect(toolResult.type).toBe('tool_result');
|
||||
expect(textContent.type).toBe('text');
|
||||
});
|
||||
it("should handle complex user messages with tool results", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "Here is the result:" },
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: "tool-1",
|
||||
content: "Tool output",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle complex assistant messages with tool calls', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{ type: 'text', text: 'Let me help you with that.' },
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'tool-1',
|
||||
name: 'calculator',
|
||||
input: { operation: 'add', numbers: [2, 2] }
|
||||
}
|
||||
]
|
||||
}];
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].role).toBe('assistant');
|
||||
expect(result[0].content).toHaveLength(2);
|
||||
const [toolCall, textContent] = result[0].content as [MockLanguageModelToolCallPart, MockLanguageModelTextPart];
|
||||
expect(toolCall.type).toBe('tool_call');
|
||||
expect(textContent.type).toBe('text');
|
||||
});
|
||||
expect(result).toHaveLength(1)
|
||||
expect(result[0].role).toBe("user")
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
const [toolResult, textContent] = result[0].content as [
|
||||
MockLanguageModelToolResultPart,
|
||||
MockLanguageModelTextPart,
|
||||
]
|
||||
expect(toolResult.type).toBe("tool_result")
|
||||
expect(textContent.type).toBe("text")
|
||||
})
|
||||
|
||||
it('should handle image blocks with appropriate placeholders', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Look at this:' },
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'image/png',
|
||||
data: 'base64data'
|
||||
}
|
||||
}
|
||||
]
|
||||
}];
|
||||
it("should handle complex assistant messages with tool calls", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "assistant",
|
||||
content: [
|
||||
{ type: "text", text: "Let me help you with that." },
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "tool-1",
|
||||
name: "calculator",
|
||||
input: { operation: "add", numbers: [2, 2] },
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart;
|
||||
expect(imagePlaceholder.value).toContain('[Image (base64): image/png not supported by VSCode LM API]');
|
||||
});
|
||||
});
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
describe('convertToAnthropicRole', () => {
|
||||
it('should convert assistant role correctly', () => {
|
||||
const result = convertToAnthropicRole('assistant' as any);
|
||||
expect(result).toBe('assistant');
|
||||
});
|
||||
expect(result).toHaveLength(1)
|
||||
expect(result[0].role).toBe("assistant")
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
const [toolCall, textContent] = result[0].content as [
|
||||
MockLanguageModelToolCallPart,
|
||||
MockLanguageModelTextPart,
|
||||
]
|
||||
expect(toolCall.type).toBe("tool_call")
|
||||
expect(textContent.type).toBe("text")
|
||||
})
|
||||
|
||||
it('should convert user role correctly', () => {
|
||||
const result = convertToAnthropicRole('user' as any);
|
||||
expect(result).toBe('user');
|
||||
});
|
||||
it("should handle image blocks with appropriate placeholders", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "Look at this:" },
|
||||
{
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: "image/png",
|
||||
data: "base64data",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should return null for unknown roles', () => {
|
||||
const result = convertToAnthropicRole('unknown' as any);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
describe('convertToAnthropicMessage', () => {
|
||||
it('should convert assistant message with text content', async () => {
|
||||
expect(result).toHaveLength(1)
|
||||
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart
|
||||
expect(imagePlaceholder.value).toContain("[Image (base64): image/png not supported by VSCode LM API]")
|
||||
})
|
||||
})
|
||||
|
||||
describe("convertToAnthropicRole", () => {
|
||||
it("should convert assistant role correctly", () => {
|
||||
const result = convertToAnthropicRole("assistant" as any)
|
||||
expect(result).toBe("assistant")
|
||||
})
|
||||
|
||||
it("should convert user role correctly", () => {
|
||||
const result = convertToAnthropicRole("user" as any)
|
||||
expect(result).toBe("user")
|
||||
})
|
||||
|
||||
it("should return null for unknown roles", () => {
|
||||
const result = convertToAnthropicRole("unknown" as any)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe("convertToAnthropicMessage", () => {
|
||||
it("should convert assistant message with text content", async () => {
|
||||
const vsCodeMessage = {
|
||||
role: 'assistant',
|
||||
name: 'assistant',
|
||||
content: [new vscode.LanguageModelTextPart('Hello')]
|
||||
};
|
||||
role: "assistant",
|
||||
name: "assistant",
|
||||
content: [new vscode.LanguageModelTextPart("Hello")],
|
||||
}
|
||||
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any);
|
||||
|
||||
expect(result.role).toBe('assistant');
|
||||
expect(result.content).toHaveLength(1);
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any)
|
||||
|
||||
expect(result.role).toBe("assistant")
|
||||
expect(result.content).toHaveLength(1)
|
||||
expect(result.content[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
expect(result.id).toBe('test-uuid');
|
||||
});
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
})
|
||||
expect(result.id).toBe("test-uuid")
|
||||
})
|
||||
|
||||
it('should convert assistant message with tool calls', async () => {
|
||||
it("should convert assistant message with tool calls", async () => {
|
||||
const vsCodeMessage = {
|
||||
role: 'assistant',
|
||||
name: 'assistant',
|
||||
content: [new vscode.LanguageModelToolCallPart(
|
||||
'call-1',
|
||||
'calculator',
|
||||
{ operation: 'add', numbers: [2, 2] }
|
||||
)]
|
||||
};
|
||||
role: "assistant",
|
||||
name: "assistant",
|
||||
content: [
|
||||
new vscode.LanguageModelToolCallPart("call-1", "calculator", { operation: "add", numbers: [2, 2] }),
|
||||
],
|
||||
}
|
||||
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any);
|
||||
|
||||
expect(result.content).toHaveLength(1);
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any)
|
||||
|
||||
expect(result.content).toHaveLength(1)
|
||||
expect(result.content[0]).toEqual({
|
||||
type: 'tool_use',
|
||||
id: 'call-1',
|
||||
name: 'calculator',
|
||||
input: { operation: 'add', numbers: [2, 2] }
|
||||
});
|
||||
expect(result.id).toBe('test-uuid');
|
||||
});
|
||||
type: "tool_use",
|
||||
id: "call-1",
|
||||
name: "calculator",
|
||||
input: { operation: "add", numbers: [2, 2] },
|
||||
})
|
||||
expect(result.id).toBe("test-uuid")
|
||||
})
|
||||
|
||||
it('should throw error for non-assistant messages', async () => {
|
||||
it("should throw error for non-assistant messages", async () => {
|
||||
const vsCodeMessage = {
|
||||
role: 'user',
|
||||
name: 'user',
|
||||
content: [new vscode.LanguageModelTextPart('Hello')]
|
||||
};
|
||||
role: "user",
|
||||
name: "user",
|
||||
content: [new vscode.LanguageModelTextPart("Hello")],
|
||||
}
|
||||
|
||||
await expect(convertToAnthropicMessage(vsCodeMessage as any))
|
||||
.rejects
|
||||
.toThrow('Cline <Language Model API>: Only assistant messages are supported.');
|
||||
});
|
||||
});
|
||||
});
|
||||
await expect(convertToAnthropicMessage(vsCodeMessage as any)).rejects.toThrow(
|
||||
"Cline <Language Model API>: Only assistant messages are supported.",
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,210 +8,216 @@ import { StreamEvent } from "../providers/bedrock"
/**
 * Convert Anthropic messages to Bedrock Converse format
 */
export function convertToBedrockConverseMessages(
anthropicMessages: Anthropic.Messages.MessageParam[]
): Message[] {
return anthropicMessages.map(anthropicMessage => {
// Map Anthropic roles to Bedrock roles
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
return anthropicMessages.map((anthropicMessage) => {
// Map Anthropic roles to Bedrock roles
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"

if (typeof anthropicMessage.content === "string") {
return {
role,
content: [{
text: anthropicMessage.content
}] as ContentBlock[]
}
}
if (typeof anthropicMessage.content === "string") {
return {
role,
content: [
{
text: anthropicMessage.content,
},
] as ContentBlock[],
}
}

// Process complex content types
const content = anthropicMessage.content.map(block => {
const messageBlock = block as MessageContent & {
id?: string,
tool_use_id?: string,
content?: Array<{ type: string, text: string }>,
output?: string | Array<{ type: string, text: string }>
}
// Process complex content types
const content = anthropicMessage.content.map((block) => {
const messageBlock = block as MessageContent & {
id?: string
tool_use_id?: string
content?: Array<{ type: string; text: string }>
output?: string | Array<{ type: string; text: string }>
}

if (messageBlock.type === "text") {
return {
text: messageBlock.text || ''
} as ContentBlock
}

if (messageBlock.type === "image" && messageBlock.source) {
// Convert base64 string to byte array if needed
let byteArray: Uint8Array
if (typeof messageBlock.source.data === 'string') {
const binaryString = atob(messageBlock.source.data)
byteArray = new Uint8Array(binaryString.length)
for (let i = 0; i < binaryString.length; i++) {
byteArray[i] = binaryString.charCodeAt(i)
}
} else {
byteArray = messageBlock.source.data
}
if (messageBlock.type === "text") {
return {
text: messageBlock.text || "",
} as ContentBlock
}

// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
const format = messageBlock.source.media_type.split('/')[1]
if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) {
throw new Error(`Unsupported image format: ${format}`)
}
if (messageBlock.type === "image" && messageBlock.source) {
// Convert base64 string to byte array if needed
let byteArray: Uint8Array
if (typeof messageBlock.source.data === "string") {
const binaryString = atob(messageBlock.source.data)
byteArray = new Uint8Array(binaryString.length)
for (let i = 0; i < binaryString.length; i++) {
byteArray[i] = binaryString.charCodeAt(i)
}
} else {
byteArray = messageBlock.source.data
}

return {
image: {
format: format as "png" | "jpeg" | "gif" | "webp",
source: {
bytes: byteArray
}
}
} as ContentBlock
}
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
const format = messageBlock.source.media_type.split("/")[1]
if (!["png", "jpeg", "gif", "webp"].includes(format)) {
throw new Error(`Unsupported image format: ${format}`)
}

if (messageBlock.type === "tool_use") {
|
||||
// Convert tool use to XML format
|
||||
const toolParams = Object.entries(messageBlock.input || {})
|
||||
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
|
||||
.join('\n')
|
||||
return {
|
||||
image: {
|
||||
format: format as "png" | "jpeg" | "gif" | "webp",
|
||||
source: {
|
||||
bytes: byteArray,
|
||||
},
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
return {
|
||||
toolUse: {
|
||||
toolUseId: messageBlock.id || '',
|
||||
name: messageBlock.name || '',
|
||||
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`
|
||||
}
|
||||
} as ContentBlock
|
||||
}
|
||||
if (messageBlock.type === "tool_use") {
|
||||
// Convert tool use to XML format
|
||||
const toolParams = Object.entries(messageBlock.input || {})
|
||||
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
|
||||
.join("\n")
|
||||
|
||||
if (messageBlock.type === "tool_result") {
|
||||
// First try to use content if available
|
||||
if (messageBlock.content && Array.isArray(messageBlock.content)) {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: messageBlock.content.map(item => ({
|
||||
text: item.text
|
||||
})),
|
||||
status: "success"
|
||||
}
|
||||
} as ContentBlock
|
||||
}
|
||||
return {
|
||||
toolUse: {
|
||||
toolUseId: messageBlock.id || "",
|
||||
name: messageBlock.name || "",
|
||||
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`,
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
// Fall back to output handling if content is not available
|
||||
if (messageBlock.output && typeof messageBlock.output === "string") {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: [{
|
||||
text: messageBlock.output
|
||||
}],
|
||||
status: "success"
|
||||
}
|
||||
} as ContentBlock
|
||||
}
|
||||
// Handle array of content blocks if output is an array
|
||||
if (Array.isArray(messageBlock.output)) {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: messageBlock.output.map(part => {
|
||||
if (typeof part === "object" && "text" in part) {
|
||||
return { text: part.text }
|
||||
}
|
||||
// Skip images in tool results as they're handled separately
|
||||
if (typeof part === "object" && "type" in part && part.type === "image") {
|
||||
return { text: "(see following message for image)" }
|
||||
}
|
||||
return { text: String(part) }
|
||||
}),
|
||||
status: "success"
|
||||
}
|
||||
} as ContentBlock
|
||||
}
|
||||
if (messageBlock.type === "tool_result") {
|
||||
// First try to use content if available
|
||||
if (messageBlock.content && Array.isArray(messageBlock.content)) {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || "",
|
||||
content: messageBlock.content.map((item) => ({
|
||||
text: item.text,
|
||||
})),
|
||||
status: "success",
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
// Default case
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || '',
content: [{
text: String(messageBlock.output || '')
}],
status: "success"
}
} as ContentBlock
}
// Fall back to output handling if content is not available
if (messageBlock.output && typeof messageBlock.output === "string") {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: [
{
text: messageBlock.output,
},
],
status: "success",
},
} as ContentBlock
}
// Handle array of content blocks if output is an array
if (Array.isArray(messageBlock.output)) {
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: messageBlock.output.map((part) => {
if (typeof part === "object" && "text" in part) {
return { text: part.text }
}
// Skip images in tool results as they're handled separately
if (typeof part === "object" && "type" in part && part.type === "image") {
return { text: "(see following message for image)" }
}
return { text: String(part) }
}),
status: "success",
},
} as ContentBlock
}

if (messageBlock.type === "video") {
const videoContent = messageBlock.s3Location ? {
s3Location: {
uri: messageBlock.s3Location.uri,
bucketOwner: messageBlock.s3Location.bucketOwner
}
} : messageBlock.source
// Default case
return {
toolResult: {
toolUseId: messageBlock.tool_use_id || "",
content: [
{
text: String(messageBlock.output || ""),
},
],
status: "success",
},
} as ContentBlock
}

return {
video: {
format: "mp4", // Default to mp4, adjust based on actual format if needed
source: videoContent
}
} as ContentBlock
}
if (messageBlock.type === "video") {
const videoContent = messageBlock.s3Location
? {
s3Location: {
uri: messageBlock.s3Location.uri,
bucketOwner: messageBlock.s3Location.bucketOwner,
},
}
: messageBlock.source

// Default case for unknown block types
return {
text: '[Unknown Block Type]'
} as ContentBlock
})
return {
video: {
format: "mp4", // Default to mp4, adjust based on actual format if needed
source: videoContent,
},
} as ContentBlock
}

return {
role,
content
}
})
// Default case for unknown block types
return {
text: "[Unknown Block Type]",
} as ContentBlock
})

return {
role,
content,
}
})
}

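For orientation, a minimal usage sketch of the reformatted convertToBedrockConverseMessages; it assumes the Message and ContentBlock types imported at the top of this file are in scope, and the sample messages are hypothetical rather than part of the commit:

// Hypothetical call site: a short Anthropic-style exchange mapped to Bedrock Converse messages.
const bedrockMessages = convertToBedrockConverseMessages([
	{ role: "user", content: "Describe this image" },
	{ role: "assistant", content: [{ type: "text", text: "It appears to be a cat." }] },
])
// bedrockMessages[0].content -> [{ text: "Describe this image" }]
// bedrockMessages[1].content -> [{ text: "It appears to be a cat." }]
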
/**
 * Convert Bedrock Converse stream events to Anthropic message format
 */
export function convertToAnthropicMessage(
streamEvent: StreamEvent,
modelId: string
streamEvent: StreamEvent,
modelId: string,
): Partial<Anthropic.Messages.Message> {
// Handle metadata events
if (streamEvent.metadata?.usage) {
return {
id: '', // Bedrock doesn't provide message IDs
type: "message",
role: "assistant",
model: modelId,
usage: {
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
output_tokens: streamEvent.metadata.usage.outputTokens || 0
}
}
}
// Handle metadata events
if (streamEvent.metadata?.usage) {
return {
id: "", // Bedrock doesn't provide message IDs
type: "message",
role: "assistant",
model: modelId,
usage: {
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
output_tokens: streamEvent.metadata.usage.outputTokens || 0,
},
}
}

// Handle content blocks
const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text
if (text !== undefined) {
return {
type: "message",
role: "assistant",
content: [{ type: "text", text: text }],
model: modelId
}
}
// Handle content blocks
const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text
if (text !== undefined) {
return {
type: "message",
role: "assistant",
content: [{ type: "text", text: text }],
model: modelId,
}
}

// Handle message stop
if (streamEvent.messageStop) {
return {
type: "message",
role: "assistant",
stop_reason: streamEvent.messageStop.stopReason || null,
stop_sequence: null,
model: modelId
}
}
// Handle message stop
if (streamEvent.messageStop) {
return {
type: "message",
role: "assistant",
stop_reason: streamEvent.messageStop.stopReason || null,
stop_sequence: null,
model: modelId,
}
}

return {}
return {}
}

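A sketch of how this stream-event converter might be consumed; events and modelId are assumed to come from the surrounding Bedrock provider code, and the loop only illustrates the text and usage branches:

// Illustration only: map each Converse stream event to a partial Anthropic message.
for (const event of events as StreamEvent[]) {
	const partial = convertToAnthropicMessage(event, modelId)
	const block = partial.content?.[0]
	if (block && block.type === "text") {
		process.stdout.write(block.text)
	}
	if (partial.usage) {
		console.log("tokens:", partial.usage.input_tokens, partial.usage.output_tokens)
	}
}
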
@@ -1,5 +1,5 @@
import { Anthropic } from "@anthropic-ai/sdk";
import * as vscode from 'vscode';
import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode"

/**
 * Safely converts a value into a plain object.
@@ -7,30 +7,31 @@ import * as vscode from 'vscode';
function asObjectSafe(value: any): object {
// Handle null/undefined
if (!value) {
return {};
return {}
}

try {
// Handle strings that might be JSON
if (typeof value === 'string') {
return JSON.parse(value);
if (typeof value === "string") {
return JSON.parse(value)
}

// Handle pre-existing objects
if (typeof value === 'object') {
return Object.assign({}, value);
if (typeof value === "object") {
return Object.assign({}, value)
}

return {};
}
catch (error) {
console.warn('Cline <Language Model API>: Failed to parse object:', error);
return {};
return {}
} catch (error) {
console.warn("Cline <Language Model API>: Failed to parse object:", error)
return {}
}
}

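As a quick reference, the behaviour asObjectSafe is meant to preserve through the reformat, shown as comments since the helper is module-private:

// asObjectSafe('{"a": 1}')  -> { a: 1 }   (JSON string parsed)
// asObjectSafe({ b: 2 })    -> { b: 2 }   (shallow copy of an existing object)
// asObjectSafe(undefined)   -> {}         (null/undefined and parse failures fall back to {})
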
export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): vscode.LanguageModelChatMessage[] {
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [];
export function convertToVsCodeLmMessages(
anthropicMessages: Anthropic.Messages.MessageParam[],
): vscode.LanguageModelChatMessage[] {
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []

for (const anthropicMessage of anthropicMessages) {
// Handle simple string messages
@@ -38,135 +39,129 @@ export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.
vsCodeLmMessages.push(
anthropicMessage.role === "assistant"
? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content)
: vscode.LanguageModelChatMessage.User(anthropicMessage.content)
);
continue;
: vscode.LanguageModelChatMessage.User(anthropicMessage.content),
)
continue
}

// Handle complex message structures
switch (anthropicMessage.role) {
case "user": {
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
toolMessages: Anthropic.ToolResultBlockParam[];
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
toolMessages: Anthropic.ToolResultBlockParam[]
}>(
(acc, part) => {
if (part.type === "tool_result") {
acc.toolMessages.push(part);
acc.toolMessages.push(part)
} else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part)
}
else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part);
}
return acc;
return acc
},
{ nonToolMessages: [], toolMessages: [] },
);
)

// Process tool messages first then non-tool messages
const contentParts = [
// Convert tool messages to ToolResultParts
...toolMessages.map((toolMessage) => {
// Process tool result content into TextParts
const toolContentParts: vscode.LanguageModelTextPart[] = (
const toolContentParts: vscode.LanguageModelTextPart[] =
typeof toolMessage.content === "string"
? [new vscode.LanguageModelTextPart(toolMessage.content)]
: (
toolMessage.content?.map((part) => {
: (toolMessage.content?.map((part) => {
if (part.type === "image") {
return new vscode.LanguageModelTextPart(
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
);
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
)
}
return new vscode.LanguageModelTextPart(part.text);
})
?? [new vscode.LanguageModelTextPart("")]
)
);
return new vscode.LanguageModelTextPart(part.text)
}) ?? [new vscode.LanguageModelTextPart("")])

return new vscode.LanguageModelToolResultPart(
toolMessage.tool_use_id,
toolContentParts
);
return new vscode.LanguageModelToolResultPart(toolMessage.tool_use_id, toolContentParts)
}),

// Convert non-tool messages to TextParts after tool messages
...nonToolMessages.map((part) => {
if (part.type === "image") {
return new vscode.LanguageModelTextPart(
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
);
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
)
}
return new vscode.LanguageModelTextPart(part.text);
})
];
return new vscode.LanguageModelTextPart(part.text)
}),
]

// Add single user message with all content parts
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts));
break;
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts))
break
}

case "assistant": {
|
||||
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
|
||||
toolMessages: Anthropic.ToolUseBlockParam[];
|
||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
|
||||
toolMessages: Anthropic.ToolUseBlockParam[]
|
||||
}>(
|
||||
(acc, part) => {
|
||||
if (part.type === "tool_use") {
|
||||
acc.toolMessages.push(part);
|
||||
acc.toolMessages.push(part)
|
||||
} else if (part.type === "text" || part.type === "image") {
|
||||
acc.nonToolMessages.push(part)
|
||||
}
|
||||
else if (part.type === "text" || part.type === "image") {
|
||||
acc.nonToolMessages.push(part);
|
||||
}
|
||||
return acc;
|
||||
return acc
|
||||
},
|
||||
{ nonToolMessages: [], toolMessages: [] },
|
||||
);
|
||||
)
|
||||
|
||||
// Process tool messages first then non-tool messages
|
||||
// Process tool messages first then non-tool messages
|
||||
const contentParts = [
|
||||
// Convert tool messages to ToolCallParts first
|
||||
...toolMessages.map((toolMessage) =>
|
||||
new vscode.LanguageModelToolCallPart(
|
||||
toolMessage.id,
|
||||
toolMessage.name,
|
||||
asObjectSafe(toolMessage.input)
|
||||
)
|
||||
...toolMessages.map(
|
||||
(toolMessage) =>
|
||||
new vscode.LanguageModelToolCallPart(
|
||||
toolMessage.id,
|
||||
toolMessage.name,
|
||||
asObjectSafe(toolMessage.input),
|
||||
),
|
||||
),
|
||||
|
||||
// Convert non-tool messages to TextParts after tool messages
|
||||
...nonToolMessages.map((part) => {
|
||||
if (part.type === "image") {
|
||||
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]");
|
||||
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]")
|
||||
}
|
||||
return new vscode.LanguageModelTextPart(part.text);
|
||||
})
|
||||
];
|
||||
return new vscode.LanguageModelTextPart(part.text)
|
||||
}),
|
||||
]
|
||||
|
||||
// Add the assistant message to the list of messages
|
||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts));
|
||||
break;
|
||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vsCodeLmMessages;
|
||||
return vsCodeLmMessages
|
||||
}
|
||||
|
||||
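A minimal sketch of convertToVsCodeLmMessages in use, assuming it runs inside a VS Code extension host; the two sample turns are hypothetical:

// Illustration only: one plain user turn plus one assistant tool call.
const lmMessages = convertToVsCodeLmMessages([
	{ role: "user", content: "What is 2 + 2?" },
	{
		role: "assistant",
		content: [{ type: "tool_use", id: "call-1", name: "calculator", input: { a: 2, b: 2 } }],
	},
])
// lmMessages[0] is a User message; lmMessages[1] is an Assistant message whose
// content starts with a LanguageModelToolCallPart built via asObjectSafe(input).
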
export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null {
switch (vsCodeLmMessageRole) {
case vscode.LanguageModelChatMessageRole.Assistant:
return "assistant";
return "assistant"
case vscode.LanguageModelChatMessageRole.User:
return "user";
return "user"
default:
return null;
return null
}
}

export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.LanguageModelChatMessage): Promise<Anthropic.Messages.Message> {
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role);
export async function convertToAnthropicMessage(
vsCodeLmMessage: vscode.LanguageModelChatMessage,
): Promise<Anthropic.Messages.Message> {
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role)
if (anthropicRole !== "assistant") {
throw new Error("Cline <Language Model API>: Only assistant messages are supported.");
throw new Error("Cline <Language Model API>: Only assistant messages are supported.")
}

return {
@@ -174,36 +169,32 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
type: "message",
model: "vscode-lm",
role: anthropicRole,
content: (
vsCodeLmMessage.content
.map((part): Anthropic.ContentBlock | null => {
if (part instanceof vscode.LanguageModelTextPart) {
return {
type: "text",
text: part.value
};
content: vsCodeLmMessage.content
.map((part): Anthropic.ContentBlock | null => {
if (part instanceof vscode.LanguageModelTextPart) {
return {
type: "text",
text: part.value,
}
}

if (part instanceof vscode.LanguageModelToolCallPart) {
return {
type: "tool_use",
id: part.callId || crypto.randomUUID(),
name: part.name,
input: asObjectSafe(part.input)
};
if (part instanceof vscode.LanguageModelToolCallPart) {
return {
type: "tool_use",
id: part.callId || crypto.randomUUID(),
name: part.name,
input: asObjectSafe(part.input),
}
}

return null;
})
.filter(
(part): part is Anthropic.ContentBlock => part !== null
)
),
return null
})
.filter((part): part is Anthropic.ContentBlock => part !== null),
stop_reason: null,
stop_sequence: null,
usage: {
input_tokens: 0,
output_tokens: 0,
}
};
},
}
}

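And a sketch of the assistant-only conversion back to an Anthropic-shaped message, again assuming a VS Code extension host; non-assistant input throws, as the tests above exercise:

// Illustration only: a text-only assistant message round-tripped to Anthropic shape.
const assistantMessage = vscode.LanguageModelChatMessage.Assistant([new vscode.LanguageModelTextPart("4")])
const anthropicMessage = await convertToAnthropicMessage(assistantMessage)
// anthropicMessage.role is "assistant" and
// anthropicMessage.content is [{ type: "text", text: "4" }]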