From ca41c54cb5c480718b8b3dc05227a4c00362887a Mon Sep 17 00:00:00 2001 From: Cline Date: Tue, 10 Dec 2024 23:48:08 +0200 Subject: [PATCH] test(bedrock): add comprehensive test coverage for Bedrock integration - Add tests for AWS Bedrock handler (stream handling, config, errors) - Add tests for message format conversion (text, images, tools) - Add tests for stream event parsing and transformation - Add tests for cross-region inference and prompt cache - Add tests for metadata and message lifecycle events --- src/api/providers/__tests__/bedrock.test.ts | 243 +++++++++++++++++ .../__tests__/bedrock-converse-format.test.ts | 252 ++++++++++++++++++ 2 files changed, 495 insertions(+) create mode 100644 src/api/providers/__tests__/bedrock.test.ts create mode 100644 src/api/transform/__tests__/bedrock-converse-format.test.ts diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts new file mode 100644 index 0000000..c3285bd --- /dev/null +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -0,0 +1,243 @@ +import { AwsBedrockHandler } from '../bedrock'; +import { + BedrockRuntimeClient, + ConverseStreamCommand, + ConverseStreamCommandOutput +} from '@aws-sdk/client-bedrock-runtime'; +import { ApiHandlerOptions } from '../../../shared/api'; +import { jest } from '@jest/globals'; +import { Readable } from 'stream'; + +// Mock the BedrockRuntimeClient +jest.mock('@aws-sdk/client-bedrock-runtime', () => ({ + BedrockRuntimeClient: jest.fn().mockImplementation(() => ({ + send: jest.fn() + })), + ConverseStreamCommand: jest.fn() +})); + +describe('AwsBedrockHandler', () => { + let handler: AwsBedrockHandler; + let mockClient: jest.Mocked; + + beforeEach(() => { + // Clear all mocks + jest.clearAllMocks(); + + // Create mock client with properly typed send method + mockClient = { + send: jest.fn().mockImplementation(() => Promise.resolve({ + $metadata: {}, + stream: new Readable({ + read() { + this.push(null); + } + }) + })) + } as unknown as jest.Mocked; + + // Create handler with test options + const options: ApiHandlerOptions = { + awsRegion: 'us-west-2', + awsAccessKey: 'test-access-key', + awsSecretKey: 'test-secret-key', + apiModelId: 'test-model' + }; + handler = new AwsBedrockHandler(options); + (handler as any).client = mockClient; + }); + + test('createMessage sends a streaming request correctly', async () => { + const mockStream = new Readable({ + read() { + this.push(JSON.stringify({ + messageStart: { role: 'assistant' } + })); + this.push(JSON.stringify({ + contentBlockStart: { + start: { text: 'Hello' } + } + })); + this.push(JSON.stringify({ + contentBlockDelta: { + delta: { text: ' world' } + } + })); + this.push(JSON.stringify({ + messageStop: { stopReason: 'end_turn' } + })); + this.push(null); + } + }); + + mockClient.send.mockImplementation(() => + Promise.resolve({ + $metadata: {}, + stream: mockStream + } as ConverseStreamCommandOutput) + ); + + const systemPrompt = 'Test system prompt'; + const messages = [{ role: 'user' as const, content: 'Test message' }]; + + const stream = handler.createMessage(systemPrompt, messages); + + // Collect all chunks + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + + // Verify the command was sent correctly + expect(mockClient.send).toHaveBeenCalledWith( + expect.any(ConverseStreamCommand) + ); + + // Verify the stream chunks + expect(chunks).toEqual([ + { type: 'text', text: 'Hello' }, + { type: 'text', text: ' world' } + ]); + }); + + test('createMessage handles metadata events correctly', async () => { + const mockStream = new Readable({ + read() { + this.push(JSON.stringify({ + metadata: { + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30 + } + } + })); + this.push(null); + } + }); + + mockClient.send.mockImplementation(() => + Promise.resolve({ + $metadata: {}, + stream: mockStream + } as ConverseStreamCommandOutput) + ); + + const systemPrompt = 'Test system prompt'; + const messages = [{ role: 'user' as const, content: 'Test message' }]; + + const stream = handler.createMessage(systemPrompt, messages); + + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + + expect(chunks).toEqual([ + { + type: 'usage', + inputTokens: 10, + outputTokens: 20 + } + ]); + }); + + test('createMessage handles errors during streaming', async () => { + mockClient.send.mockImplementation(() => + Promise.reject(new Error('Test error')) + ); + + const systemPrompt = 'Test system prompt'; + const messages = [{ role: 'user' as const, content: 'Test message' }]; + + await expect(handler.createMessage(systemPrompt, messages)).rejects.toThrow('Test error'); + }); + + test('getModel returns correct model info', () => { + const modelInfo = handler.getModel(); + expect(modelInfo).toEqual({ + id: 'test-model', + info: expect.any(Object) + }); + }); + + test('createMessage handles cross-region inference', async () => { + const options: ApiHandlerOptions = { + awsRegion: 'us-west-2', + awsAccessKey: 'test-access-key', + awsSecretKey: 'test-secret-key', + apiModelId: 'test-model', + awsUseCrossRegionInference: true + }; + + handler = new AwsBedrockHandler(options); + (handler as any).client = mockClient; + + const mockStream = new Readable({ + read() { + this.push(JSON.stringify({ + contentBlockStart: { + start: { text: 'Hello' } + } + })); + this.push(null); + } + }); + + mockClient.send.mockImplementation(() => + Promise.resolve({ + $metadata: {}, + stream: mockStream + } as ConverseStreamCommandOutput) + ); + + const systemPrompt = 'Test system prompt'; + const messages = [{ role: 'user' as const, content: 'Test message' }]; + + await handler.createMessage(systemPrompt, messages); + + expect(mockClient.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.stringContaining('us.test-model') + }) + ); + }); + + test('createMessage includes prompt cache configuration when enabled', async () => { + const options: ApiHandlerOptions = { + awsRegion: 'us-west-2', + awsAccessKey: 'test-access-key', + awsSecretKey: 'test-secret-key', + apiModelId: 'test-model', + awsUsePromptCache: true, + awspromptCacheId: 'test-cache-id' + }; + + handler = new AwsBedrockHandler(options); + (handler as any).client = mockClient; + + const mockStream = new Readable({ + read() { + this.push(null); + } + }); + + mockClient.send.mockImplementation(() => + Promise.resolve({ + $metadata: {}, + stream: mockStream + } as ConverseStreamCommandOutput) + ); + + const systemPrompt = 'Test system prompt'; + const messages = [{ role: 'user' as const, content: 'Test message' }]; + + await handler.createMessage(systemPrompt, messages); + + expect(mockClient.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.stringContaining('promptCacheId') + }) + ); + }); +}); diff --git a/src/api/transform/__tests__/bedrock-converse-format.test.ts b/src/api/transform/__tests__/bedrock-converse-format.test.ts new file mode 100644 index 0000000..c9a0190 --- /dev/null +++ b/src/api/transform/__tests__/bedrock-converse-format.test.ts @@ -0,0 +1,252 @@ +import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format' +import { Anthropic } from '@anthropic-ai/sdk' +import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime' +import { StreamEvent } from '../../providers/bedrock' + +describe('bedrock-converse-format', () => { + describe('convertToBedrockConverseMessages', () => { + test('converts simple text messages correctly', () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { role: 'user', content: 'Hello' }, + { role: 'assistant', content: 'Hi there' } + ] + + const result = convertToBedrockConverseMessages(messages) + + expect(result).toEqual([ + { + role: 'user', + content: [{ text: 'Hello' }] + }, + { + role: 'assistant', + content: [{ text: 'Hi there' }] + } + ]) + }) + + test('converts messages with images correctly', () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Look at this image:' + }, + { + type: 'image', + source: { + type: 'base64', + data: 'SGVsbG8=', // "Hello" in base64 + media_type: 'image/jpeg' as const + } + } + ] + } + ] + + const result = convertToBedrockConverseMessages(messages) + + if (!result[0] || !result[0].content) { + fail('Expected result to have content') + return + } + + expect(result[0].role).toBe('user') + expect(result[0].content).toHaveLength(2) + expect(result[0].content[0]).toEqual({ text: 'Look at this image:' }) + + const imageBlock = result[0].content[1] as ContentBlock + if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) { + expect(imageBlock.image.format).toBe('jpeg') + expect(imageBlock.image.source).toBeDefined() + expect(imageBlock.image.source.bytes).toBeDefined() + } else { + fail('Expected image block not found') + } + }) + + test('converts tool use messages correctly', () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'test-id', + name: 'read_file', + input: { + path: 'test.txt' + } + } + ] + } + ] + + const result = convertToBedrockConverseMessages(messages) + + if (!result[0] || !result[0].content) { + fail('Expected result to have content') + return + } + + expect(result[0].role).toBe('assistant') + const toolBlock = result[0].content[0] as ContentBlock + if ('toolUse' in toolBlock && toolBlock.toolUse) { + expect(toolBlock.toolUse).toEqual({ + toolUseId: 'test-id', + name: 'read_file', + input: '\n\ntest.txt\n\n' + }) + } else { + fail('Expected tool use block not found') + } + }) + + test('converts tool result messages correctly', () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: 'assistant', + content: [ + { + type: 'tool_result', + tool_use_id: 'test-id', + content: [{ type: 'text', text: 'File contents here' }] + } + ] + } + ] + + const result = convertToBedrockConverseMessages(messages) + + if (!result[0] || !result[0].content) { + fail('Expected result to have content') + return + } + + expect(result[0].role).toBe('assistant') + const resultBlock = result[0].content[0] as ContentBlock + if ('toolResult' in resultBlock && resultBlock.toolResult) { + const expectedContent: ToolResultContentBlock[] = [ + { text: 'File contents here' } + ] + expect(resultBlock.toolResult).toEqual({ + toolUseId: 'test-id', + content: expectedContent, + status: 'success' + }) + } else { + fail('Expected tool result block not found') + } + }) + + test('handles text content correctly', () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Hello world' + } + ] + } + ] + + const result = convertToBedrockConverseMessages(messages) + + if (!result[0] || !result[0].content) { + fail('Expected result to have content') + return + } + + expect(result[0].role).toBe('user') + expect(result[0].content).toHaveLength(1) + const textBlock = result[0].content[0] as ContentBlock + expect(textBlock).toEqual({ text: 'Hello world' }) + }) + }) + + describe('convertToAnthropicMessage', () => { + test('converts metadata events correctly', () => { + const event: StreamEvent = { + metadata: { + usage: { + inputTokens: 10, + outputTokens: 20 + } + } + } + + const result = convertToAnthropicMessage(event, 'test-model') + + expect(result).toEqual({ + id: '', + type: 'message', + role: 'assistant', + model: 'test-model', + usage: { + input_tokens: 10, + output_tokens: 20 + } + }) + }) + + test('converts content block start events correctly', () => { + const event: StreamEvent = { + contentBlockStart: { + start: { + text: 'Hello' + } + } + } + + const result = convertToAnthropicMessage(event, 'test-model') + + expect(result).toEqual({ + type: 'message', + role: 'assistant', + content: [{ type: 'text', text: 'Hello' }], + model: 'test-model' + }) + }) + + test('converts content block delta events correctly', () => { + const event: StreamEvent = { + contentBlockDelta: { + delta: { + text: ' world' + } + } + } + + const result = convertToAnthropicMessage(event, 'test-model') + + expect(result).toEqual({ + type: 'message', + role: 'assistant', + content: [{ type: 'text', text: ' world' }], + model: 'test-model' + }) + }) + + test('converts message stop events correctly', () => { + const event: StreamEvent = { + messageStop: { + stopReason: 'end_turn' as const + } + } + + const result = convertToAnthropicMessage(event, 'test-model') + + expect(result).toEqual({ + type: 'message', + role: 'assistant', + stop_reason: 'end_turn', + stop_sequence: null, + model: 'test-model' + }) + }) + }) +})