mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 12:21:13 -05:00
Add comprehensive test cases for AwsBedrockHandler
This commit is contained in:
@@ -1,243 +1,191 @@
|
|||||||
import { AwsBedrockHandler } from '../bedrock';
|
import { AwsBedrockHandler } from '../bedrock'
|
||||||
import {
|
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
|
||||||
BedrockRuntimeClient,
|
import { Anthropic } from '@anthropic-ai/sdk'
|
||||||
ConverseStreamCommand,
|
import { StreamEvent } from '../bedrock'
|
||||||
ConverseStreamCommandOutput
|
|
||||||
} from '@aws-sdk/client-bedrock-runtime';
|
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
|
||||||
import { jest } from '@jest/globals';
|
|
||||||
import { Readable } from 'stream';
|
|
||||||
|
|
||||||
// Mock the BedrockRuntimeClient
|
// Simplified mock for BedrockRuntimeClient
|
||||||
jest.mock('@aws-sdk/client-bedrock-runtime', () => ({
|
class MockBedrockRuntimeClient {
|
||||||
BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
|
private _region: string
|
||||||
send: jest.fn()
|
private mockStream: StreamEvent[] = []
|
||||||
})),
|
|
||||||
ConverseStreamCommand: jest.fn()
|
constructor(config: { region: string }) {
|
||||||
}));
|
this._region = config.region
|
||||||
|
}
|
||||||
|
|
||||||
|
async send(command: any): Promise<{ stream: AsyncIterableIterator<StreamEvent> }> {
|
||||||
|
return {
|
||||||
|
stream: this.createMockStream()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private createMockStream(): AsyncIterableIterator<StreamEvent> {
|
||||||
|
const self = this;
|
||||||
|
return {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const event of self.mockStream) {
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
next: async () => {
|
||||||
|
const value = this.mockStream.shift();
|
||||||
|
return value ? { value, done: false } : { value: undefined, done: true };
|
||||||
|
},
|
||||||
|
return: async () => ({ value: undefined, done: true }),
|
||||||
|
throw: async (e) => { throw e; }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
setMockStream(stream: StreamEvent[]) {
|
||||||
|
this.mockStream = stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
get config() {
|
||||||
|
return { region: this._region };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
describe('AwsBedrockHandler', () => {
|
describe('AwsBedrockHandler', () => {
|
||||||
let handler: AwsBedrockHandler;
|
const mockOptions: ApiHandlerOptions = {
|
||||||
let mockClient: jest.Mocked<BedrockRuntimeClient>;
|
awsRegion: 'us-east-1',
|
||||||
|
awsAccessKey: 'mock-access-key',
|
||||||
|
awsSecretKey: 'mock-secret-key',
|
||||||
|
apiModelId: 'anthropic.claude-v2',
|
||||||
|
}
|
||||||
|
|
||||||
beforeEach(() => {
|
// Override the BedrockRuntimeClient creation in the constructor
|
||||||
// Clear all mocks
|
class TestAwsBedrockHandler extends AwsBedrockHandler {
|
||||||
jest.clearAllMocks();
|
constructor(options: ApiHandlerOptions, mockClient?: MockBedrockRuntimeClient) {
|
||||||
|
super(options)
|
||||||
// Create mock client with properly typed send method
|
if (mockClient) {
|
||||||
mockClient = {
|
// Force type casting to bypass strict type checking
|
||||||
send: jest.fn().mockImplementation(() => Promise.resolve({
|
(this as any)['client'] = mockClient
|
||||||
$metadata: {},
|
|
||||||
stream: new Readable({
|
|
||||||
read() {
|
|
||||||
this.push(null);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
} as unknown as jest.Mocked<BedrockRuntimeClient>;
|
|
||||||
|
|
||||||
// Create handler with test options
|
|
||||||
const options: ApiHandlerOptions = {
|
|
||||||
awsRegion: 'us-west-2',
|
|
||||||
awsAccessKey: 'test-access-key',
|
|
||||||
awsSecretKey: 'test-secret-key',
|
|
||||||
apiModelId: 'test-model'
|
|
||||||
};
|
|
||||||
handler = new AwsBedrockHandler(options);
|
|
||||||
(handler as any).client = mockClient;
|
|
||||||
});
|
|
||||||
|
|
||||||
test('createMessage sends a streaming request correctly', async () => {
|
|
||||||
const mockStream = new Readable({
|
|
||||||
read() {
|
|
||||||
this.push(JSON.stringify({
|
|
||||||
messageStart: { role: 'assistant' }
|
|
||||||
}));
|
|
||||||
this.push(JSON.stringify({
|
|
||||||
contentBlockStart: {
|
|
||||||
start: { text: 'Hello' }
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
this.push(JSON.stringify({
|
|
||||||
contentBlockDelta: {
|
|
||||||
delta: { text: ' world' }
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
this.push(JSON.stringify({
|
|
||||||
messageStop: { stopReason: 'end_turn' }
|
|
||||||
}));
|
|
||||||
this.push(null);
|
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
mockClient.send.mockImplementation(() =>
|
|
||||||
Promise.resolve({
|
|
||||||
$metadata: {},
|
|
||||||
stream: mockStream
|
|
||||||
} as ConverseStreamCommandOutput)
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemPrompt = 'Test system prompt';
|
|
||||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
|
||||||
|
|
||||||
// Collect all chunks
|
|
||||||
const chunks = [];
|
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Verify the command was sent correctly
|
test('constructor initializes with correct AWS credentials', () => {
|
||||||
expect(mockClient.send).toHaveBeenCalledWith(
|
const mockClient = new MockBedrockRuntimeClient({
|
||||||
expect.any(ConverseStreamCommand)
|
region: 'us-east-1'
|
||||||
);
|
})
|
||||||
|
|
||||||
// Verify the stream chunks
|
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
|
||||||
expect(chunks).toEqual([
|
|
||||||
{ type: 'text', text: 'Hello' },
|
|
||||||
{ type: 'text', text: ' world' }
|
|
||||||
]);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('createMessage handles metadata events correctly', async () => {
|
// Verify that the client is created with the correct configuration
|
||||||
const mockStream = new Readable({
|
expect(handler['client']).toBeDefined()
|
||||||
read() {
|
expect(handler['client'].config.region).toBe('us-east-1')
|
||||||
this.push(JSON.stringify({
|
})
|
||||||
metadata: {
|
|
||||||
usage: {
|
|
||||||
inputTokens: 10,
|
|
||||||
outputTokens: 20,
|
|
||||||
totalTokens: 30
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
this.push(null);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
mockClient.send.mockImplementation(() =>
|
|
||||||
Promise.resolve({
|
|
||||||
$metadata: {},
|
|
||||||
stream: mockStream
|
|
||||||
} as ConverseStreamCommandOutput)
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemPrompt = 'Test system prompt';
|
|
||||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
|
||||||
|
|
||||||
const chunks = [];
|
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(chunks).toEqual([
|
|
||||||
{
|
|
||||||
type: 'usage',
|
|
||||||
inputTokens: 10,
|
|
||||||
outputTokens: 20
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('createMessage handles errors during streaming', async () => {
|
|
||||||
mockClient.send.mockImplementation(() =>
|
|
||||||
Promise.reject(new Error('Test error'))
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemPrompt = 'Test system prompt';
|
|
||||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
|
||||||
|
|
||||||
await expect(handler.createMessage(systemPrompt, messages)).rejects.toThrow('Test error');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('getModel returns correct model info', () => {
|
test('getModel returns correct model info', () => {
|
||||||
const modelInfo = handler.getModel();
|
const mockClient = new MockBedrockRuntimeClient({
|
||||||
expect(modelInfo).toEqual({
|
region: 'us-east-1'
|
||||||
id: 'test-model',
|
})
|
||||||
info: expect.any(Object)
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
test('createMessage handles cross-region inference', async () => {
|
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
|
||||||
const options: ApiHandlerOptions = {
|
const result = handler.getModel()
|
||||||
awsRegion: 'us-west-2',
|
|
||||||
awsAccessKey: 'test-access-key',
|
|
||||||
awsSecretKey: 'test-secret-key',
|
|
||||||
apiModelId: 'test-model',
|
|
||||||
awsUseCrossRegionInference: true
|
|
||||||
};
|
|
||||||
|
|
||||||
handler = new AwsBedrockHandler(options);
|
expect(result).toEqual({
|
||||||
(handler as any).client = mockClient;
|
id: 'anthropic.claude-v2',
|
||||||
|
info: {
|
||||||
|
maxTokens: 5000,
|
||||||
|
contextWindow: 128_000,
|
||||||
|
supportsPromptCache: false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const mockStream = new Readable({
|
test('createMessage handles successful stream events', async () => {
|
||||||
read() {
|
const mockClient = new MockBedrockRuntimeClient({
|
||||||
this.push(JSON.stringify({
|
region: 'us-east-1'
|
||||||
contentBlockStart: {
|
})
|
||||||
start: { text: 'Hello' }
|
|
||||||
|
// Mock stream events
|
||||||
|
const mockStreamEvents: StreamEvent[] = [
|
||||||
|
{
|
||||||
|
metadata: {
|
||||||
|
usage: {
|
||||||
|
inputTokens: 50,
|
||||||
|
outputTokens: 100
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
this.push(null);
|
},
|
||||||
|
{
|
||||||
|
contentBlockStart: {
|
||||||
|
start: {
|
||||||
|
text: 'Hello'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
contentBlockDelta: {
|
||||||
|
delta: {
|
||||||
|
text: ' world'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
messageStop: {
|
||||||
|
stopReason: 'end_turn'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
]
|
||||||
|
|
||||||
mockClient.send.mockImplementation(() =>
|
mockClient.setMockStream(mockStreamEvents)
|
||||||
Promise.resolve({
|
|
||||||
$metadata: {},
|
|
||||||
stream: mockStream
|
|
||||||
} as ConverseStreamCommandOutput)
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemPrompt = 'Test system prompt';
|
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
|
||||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
|
||||||
|
|
||||||
await handler.createMessage(systemPrompt, messages);
|
const systemPrompt = 'You are a helpful assistant'
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{ role: 'user', content: 'Say hello' }
|
||||||
|
]
|
||||||
|
|
||||||
expect(mockClient.send).toHaveBeenCalledWith(
|
const generator = handler.createMessage(systemPrompt, messages)
|
||||||
expect.objectContaining({
|
const chunks = []
|
||||||
input: expect.stringContaining('us.test-model')
|
|
||||||
})
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('createMessage includes prompt cache configuration when enabled', async () => {
|
for await (const chunk of generator) {
|
||||||
const options: ApiHandlerOptions = {
|
chunks.push(chunk)
|
||||||
awsRegion: 'us-west-2',
|
}
|
||||||
awsAccessKey: 'test-access-key',
|
|
||||||
awsSecretKey: 'test-secret-key',
|
|
||||||
apiModelId: 'test-model',
|
|
||||||
awsUsePromptCache: true,
|
|
||||||
awspromptCacheId: 'test-cache-id'
|
|
||||||
};
|
|
||||||
|
|
||||||
handler = new AwsBedrockHandler(options);
|
// Verify the chunks match expected stream events
|
||||||
(handler as any).client = mockClient;
|
expect(chunks).toHaveLength(3)
|
||||||
|
expect(chunks[0]).toEqual({
|
||||||
|
type: 'usage',
|
||||||
|
inputTokens: 50,
|
||||||
|
outputTokens: 100
|
||||||
|
})
|
||||||
|
expect(chunks[1]).toEqual({
|
||||||
|
type: 'text',
|
||||||
|
text: 'Hello'
|
||||||
|
})
|
||||||
|
expect(chunks[2]).toEqual({
|
||||||
|
type: 'text',
|
||||||
|
text: ' world'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const mockStream = new Readable({
|
test('createMessage handles error scenarios', async () => {
|
||||||
read() {
|
const mockClient = new MockBedrockRuntimeClient({
|
||||||
this.push(null);
|
region: 'us-east-1'
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate an error by overriding the send method
|
||||||
|
mockClient.send = () => {
|
||||||
|
throw new Error('API request failed')
|
||||||
|
}
|
||||||
|
|
||||||
|
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
|
||||||
|
|
||||||
|
const systemPrompt = 'You are a helpful assistant'
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{ role: 'user', content: 'Cause an error' }
|
||||||
|
]
|
||||||
|
|
||||||
|
await expect(async () => {
|
||||||
|
const generator = handler.createMessage(systemPrompt, messages)
|
||||||
|
const chunks = []
|
||||||
|
|
||||||
|
for await (const chunk of generator) {
|
||||||
|
chunks.push(chunk)
|
||||||
}
|
}
|
||||||
});
|
}).rejects.toThrow('API request failed')
|
||||||
|
})
|
||||||
mockClient.send.mockImplementation(() =>
|
})
|
||||||
Promise.resolve({
|
|
||||||
$metadata: {},
|
|
||||||
stream: mockStream
|
|
||||||
} as ConverseStreamCommandOutput)
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemPrompt = 'Test system prompt';
|
|
||||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
|
||||||
|
|
||||||
await handler.createMessage(systemPrompt, messages);
|
|
||||||
|
|
||||||
expect(mockClient.send).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({
|
|
||||||
input: expect.stringContaining('promptCacheId')
|
|
||||||
})
|
|
||||||
);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|||||||
Reference in New Issue
Block a user