mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-22 05:11:06 -05:00
test(bedrock): add comprehensive test coverage for Bedrock integration
- Add tests for AWS Bedrock handler (stream handling, config, errors) - Add tests for message format conversion (text, images, tools) - Add tests for stream event parsing and transformation - Add tests for cross-region inference and prompt cache - Add tests for metadata and message lifecycle events
This commit is contained in:
243
src/api/providers/__tests__/bedrock.test.ts
Normal file
243
src/api/providers/__tests__/bedrock.test.ts
Normal file
@@ -0,0 +1,243 @@
|
||||
import { AwsBedrockHandler } from '../bedrock';
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseStreamCommand,
|
||||
ConverseStreamCommandOutput
|
||||
} from '@aws-sdk/client-bedrock-runtime';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import { jest } from '@jest/globals';
|
||||
import { Readable } from 'stream';
|
||||
|
||||
// Mock the BedrockRuntimeClient
|
||||
jest.mock('@aws-sdk/client-bedrock-runtime', () => ({
|
||||
BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
|
||||
send: jest.fn()
|
||||
})),
|
||||
ConverseStreamCommand: jest.fn()
|
||||
}));
|
||||
|
||||
describe('AwsBedrockHandler', () => {
|
||||
let handler: AwsBedrockHandler;
|
||||
let mockClient: jest.Mocked<BedrockRuntimeClient>;
|
||||
|
||||
beforeEach(() => {
|
||||
// Clear all mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Create mock client with properly typed send method
|
||||
mockClient = {
|
||||
send: jest.fn().mockImplementation(() => Promise.resolve({
|
||||
$metadata: {},
|
||||
stream: new Readable({
|
||||
read() {
|
||||
this.push(null);
|
||||
}
|
||||
})
|
||||
}))
|
||||
} as unknown as jest.Mocked<BedrockRuntimeClient>;
|
||||
|
||||
// Create handler with test options
|
||||
const options: ApiHandlerOptions = {
|
||||
awsRegion: 'us-west-2',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
apiModelId: 'test-model'
|
||||
};
|
||||
handler = new AwsBedrockHandler(options);
|
||||
(handler as any).client = mockClient;
|
||||
});
|
||||
|
||||
test('createMessage sends a streaming request correctly', async () => {
|
||||
const mockStream = new Readable({
|
||||
read() {
|
||||
this.push(JSON.stringify({
|
||||
messageStart: { role: 'assistant' }
|
||||
}));
|
||||
this.push(JSON.stringify({
|
||||
contentBlockStart: {
|
||||
start: { text: 'Hello' }
|
||||
}
|
||||
}));
|
||||
this.push(JSON.stringify({
|
||||
contentBlockDelta: {
|
||||
delta: { text: ' world' }
|
||||
}
|
||||
}));
|
||||
this.push(JSON.stringify({
|
||||
messageStop: { stopReason: 'end_turn' }
|
||||
}));
|
||||
this.push(null);
|
||||
}
|
||||
});
|
||||
|
||||
mockClient.send.mockImplementation(() =>
|
||||
Promise.resolve({
|
||||
$metadata: {},
|
||||
stream: mockStream
|
||||
} as ConverseStreamCommandOutput)
|
||||
);
|
||||
|
||||
const systemPrompt = 'Test system prompt';
|
||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
|
||||
// Collect all chunks
|
||||
const chunks = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
// Verify the command was sent correctly
|
||||
expect(mockClient.send).toHaveBeenCalledWith(
|
||||
expect.any(ConverseStreamCommand)
|
||||
);
|
||||
|
||||
// Verify the stream chunks
|
||||
expect(chunks).toEqual([
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: ' world' }
|
||||
]);
|
||||
});
|
||||
|
||||
test('createMessage handles metadata events correctly', async () => {
|
||||
const mockStream = new Readable({
|
||||
read() {
|
||||
this.push(JSON.stringify({
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
}
|
||||
}));
|
||||
this.push(null);
|
||||
}
|
||||
});
|
||||
|
||||
mockClient.send.mockImplementation(() =>
|
||||
Promise.resolve({
|
||||
$metadata: {},
|
||||
stream: mockStream
|
||||
} as ConverseStreamCommandOutput)
|
||||
);
|
||||
|
||||
const systemPrompt = 'Test system prompt';
|
||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
|
||||
const chunks = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks).toEqual([
|
||||
{
|
||||
type: 'usage',
|
||||
inputTokens: 10,
|
||||
outputTokens: 20
|
||||
}
|
||||
]);
|
||||
});
|
||||
|
||||
test('createMessage handles errors during streaming', async () => {
|
||||
mockClient.send.mockImplementation(() =>
|
||||
Promise.reject(new Error('Test error'))
|
||||
);
|
||||
|
||||
const systemPrompt = 'Test system prompt';
|
||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
||||
|
||||
await expect(handler.createMessage(systemPrompt, messages)).rejects.toThrow('Test error');
|
||||
});
|
||||
|
||||
test('getModel returns correct model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo).toEqual({
|
||||
id: 'test-model',
|
||||
info: expect.any(Object)
|
||||
});
|
||||
});
|
||||
|
||||
test('createMessage handles cross-region inference', async () => {
|
||||
const options: ApiHandlerOptions = {
|
||||
awsRegion: 'us-west-2',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
apiModelId: 'test-model',
|
||||
awsUseCrossRegionInference: true
|
||||
};
|
||||
|
||||
handler = new AwsBedrockHandler(options);
|
||||
(handler as any).client = mockClient;
|
||||
|
||||
const mockStream = new Readable({
|
||||
read() {
|
||||
this.push(JSON.stringify({
|
||||
contentBlockStart: {
|
||||
start: { text: 'Hello' }
|
||||
}
|
||||
}));
|
||||
this.push(null);
|
||||
}
|
||||
});
|
||||
|
||||
mockClient.send.mockImplementation(() =>
|
||||
Promise.resolve({
|
||||
$metadata: {},
|
||||
stream: mockStream
|
||||
} as ConverseStreamCommandOutput)
|
||||
);
|
||||
|
||||
const systemPrompt = 'Test system prompt';
|
||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
||||
|
||||
await handler.createMessage(systemPrompt, messages);
|
||||
|
||||
expect(mockClient.send).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.stringContaining('us.test-model')
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
test('createMessage includes prompt cache configuration when enabled', async () => {
|
||||
const options: ApiHandlerOptions = {
|
||||
awsRegion: 'us-west-2',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
apiModelId: 'test-model',
|
||||
awsUsePromptCache: true,
|
||||
awspromptCacheId: 'test-cache-id'
|
||||
};
|
||||
|
||||
handler = new AwsBedrockHandler(options);
|
||||
(handler as any).client = mockClient;
|
||||
|
||||
const mockStream = new Readable({
|
||||
read() {
|
||||
this.push(null);
|
||||
}
|
||||
});
|
||||
|
||||
mockClient.send.mockImplementation(() =>
|
||||
Promise.resolve({
|
||||
$metadata: {},
|
||||
stream: mockStream
|
||||
} as ConverseStreamCommandOutput)
|
||||
);
|
||||
|
||||
const systemPrompt = 'Test system prompt';
|
||||
const messages = [{ role: 'user' as const, content: 'Test message' }];
|
||||
|
||||
await handler.createMessage(systemPrompt, messages);
|
||||
|
||||
expect(mockClient.send).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.stringContaining('promptCacheId')
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user