Files
Roo-Code/src/api/providers/__tests__/bedrock.test.ts
2025-01-17 14:11:28 -05:00

260 lines
6.9 KiB
TypeScript

import { AwsBedrockHandler } from "../bedrock"
import { MessageContent } from "../../../shared/api"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"
describe("AwsBedrockHandler", () => {
let handler: AwsBedrockHandler
beforeEach(() => {
handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})
})
describe("constructor", () => {
it("should initialize with provided config", () => {
expect(handler["options"].awsAccessKey).toBe("test-access-key")
expect(handler["options"].awsSecretKey).toBe("test-secret-key")
expect(handler["options"].awsRegion).toBe("us-east-1")
expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
})
it("should initialize with missing AWS credentials", () => {
const handlerWithoutCreds = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsRegion: "us-east-1",
})
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
})
})
describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello",
},
{
role: "assistant",
content: "Hi there!",
},
]
const systemPrompt = "You are a helpful assistant"
it("should handle text messages correctly", async () => {
const mockResponse = {
messages: [
{
role: "assistant",
content: [{ type: "text", text: "Hello! How can I help you?" }],
},
],
usage: {
input_tokens: 10,
output_tokens: 5,
},
}
// Mock AWS SDK invoke
const mockStream = {
[Symbol.asyncIterator]: async function* () {
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
},
},
}
},
}
const mockInvoke = jest.fn().mockResolvedValue({
stream: mockStream,
})
handler["client"] = {
send: mockInvoke,
} as unknown as BedrockRuntimeClient
const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}
expect(chunks.length).toBeGreaterThan(0)
expect(chunks[0]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
})
expect(mockInvoke).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
}),
}),
)
})
it("should handle API errors", async () => {
// Mock AWS SDK invoke with error
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
handler["client"] = {
send: mockInvoke,
} as unknown as BedrockRuntimeClient
const stream = handler.createMessage(systemPrompt, mockMessages)
await expect(async () => {
for await (const chunk of stream) {
// Should throw before yielding any chunks
}
}).rejects.toThrow("AWS Bedrock error")
})
})
describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
const mockResponse = {
output: new TextEncoder().encode(
JSON.stringify({
content: "Test response",
}),
),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
messages: expect.arrayContaining([
expect.objectContaining({
role: "user",
content: [{ text: "Test prompt" }],
}),
]),
inferenceConfig: expect.objectContaining({
maxTokens: 5000,
temperature: 0.3,
topP: 0.1,
}),
}),
}),
)
})
it("should handle API errors", async () => {
const mockError = new Error("AWS Bedrock error")
const mockSend = jest.fn().mockRejectedValue(mockError)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Bedrock completion error: AWS Bedrock error",
)
})
it("should handle invalid response format", async () => {
const mockResponse = {
output: new TextEncoder().encode("invalid json"),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
it("should handle empty response", async () => {
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({})),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})
it("should handle cross-region inference", async () => {
handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
awsUseCrossRegionInference: true,
})
const mockResponse = {
output: new TextEncoder().encode(
JSON.stringify({
content: "Test response",
}),
),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
}),
}),
)
})
})
describe("getModel", () => {
it("should return correct model info in test environment", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
})
it("should return test model info for invalid model in test environment", () => {
const invalidHandler = new AwsBedrockHandler({
apiModelId: "invalid-model",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})
const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
expect(modelInfo.info.maxTokens).toBe(5000)
expect(modelInfo.info.contextWindow).toBe(128_000)
})
})
})