import {
	BedrockRuntimeClient,
	ConverseStreamCommand,
	ConverseCommand,
	BedrockRuntimeClientConfig,
} from "@aws-sdk/client-bedrock-runtime"
import { fromIni } from "@aws-sdk/credential-providers"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"

/** Region used when `awsRegion` is not configured. */
const DEFAULT_AWS_REGION = "us-east-1"

/** Fallback token limit when the selected model does not declare `maxTokens`. */
const DEFAULT_MAX_TOKENS = 5000

/**
 * Shape of a single event from a Bedrock `ConverseStream` response, as consumed by
 * `AwsBedrockHandler.createMessage`. Only the fields this handler reads are modelled.
 */
export interface StreamEvent {
	messageStart?: {
		role?: string
	}
	messageStop?: {
		stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
		additionalModelResponseFields?: Record<string, unknown>
	}
	contentBlockStart?: {
		start?: {
			text?: string
		}
		contentBlockIndex?: number
	}
	contentBlockDelta?: {
		delta?: {
			text?: string
		}
		contentBlockIndex?: number
	}
	metadata?: {
		usage?: {
			inputTokens: number
			outputTokens: number
			totalTokens?: number // Optional: not read by this handler
		}
		metrics?: {
			latencyMs: number
		}
	}
}

/**
 * API handler that talks to Amazon Bedrock through the Converse / ConverseStream APIs.
 *
 * Credentials are resolved in this order: a named profile (when `awsUseProfile` and
 * `awsProfile` are set), then explicit access/secret keys (plus optional session token).
 * Otherwise the SDK's default credential chain is used.
 */
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
	private options: ApiHandlerOptions
	private client: BedrockRuntimeClient

	constructor(options: ApiHandlerOptions) {
		this.options = options

		const clientConfig: BedrockRuntimeClientConfig = {
			region: this.getEffectiveRegion(),
		}

		if (this.options.awsUseProfile && this.options.awsProfile) {
			// Use profile-based credentials if enabled and profile is set
			clientConfig.credentials = fromIni({
				profile: this.options.awsProfile,
			})
		} else if (this.options.awsAccessKey && this.options.awsSecretKey) {
			// Use direct credentials if provided
			clientConfig.credentials = {
				accessKeyId: this.options.awsAccessKey,
				secretAccessKey: this.options.awsSecretKey,
				...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
			}
		}

		this.client = new BedrockRuntimeClient(clientConfig)
	}

	/**
	 * Streams a model response for the given conversation.
	 *
	 * Yields `{ type: "usage" }` chunks for metadata events and `{ type: "text" }` chunks for
	 * content-block start/delta text. On failure it yields an error text chunk and a zero-usage
	 * chunk, then rethrows the original `Error` (or a generic `Error` for non-Error throwables).
	 *
	 * @param systemPrompt - System prompt sent as the Converse `system` block.
	 * @param messages - Conversation history in Anthropic format; converted to Bedrock format.
	 */
	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const modelConfig = this.getModel()
		const modelId = this.resolveInferenceModelId(modelConfig.id)

		// Convert messages to Bedrock format
		const formattedMessages = convertToBedrockConverseMessages(messages)

		const payload = {
			modelId,
			messages: formattedMessages,
			system: [{ text: systemPrompt }],
			inferenceConfig: {
				maxTokens: modelConfig.info.maxTokens || DEFAULT_MAX_TOKENS,
				temperature: 0.3,
				topP: 0.1,
				// NOTE(review): the Converse API's InferenceConfiguration does not document a
				// `promptCache` field; confirm the SDK actually forwards it rather than dropping it.
				...(this.options.awsUsePromptCache
					? {
							promptCache: {
								promptCacheId: this.options.awspromptCacheId || "",
							},
						}
					: {}),
			},
		}

		try {
			const command = new ConverseStreamCommand(payload)
			const response = await this.client.send(command)

			if (!response.stream) {
				throw new Error("No stream available in the response")
			}

			for await (const chunk of response.stream) {
				// Chunks may arrive as JSON strings (e.g. from test doubles); parse those.
				let streamEvent: StreamEvent
				try {
					streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
				} catch (e) {
					console.error("Failed to parse stream event:", e)
					continue
				}

				// Handle metadata events first
				if (streamEvent.metadata?.usage) {
					yield {
						type: "usage",
						inputTokens: streamEvent.metadata.usage.inputTokens || 0,
						outputTokens: streamEvent.metadata.usage.outputTokens || 0,
					}
					continue
				}

				// Message start carries nothing to emit
				if (streamEvent.messageStart) {
					continue
				}

				// Handle content blocks
				if (streamEvent.contentBlockStart?.start?.text) {
					yield {
						type: "text",
						text: streamEvent.contentBlockStart.start.text,
					}
					continue
				}

				// Handle content deltas
				if (streamEvent.contentBlockDelta?.delta?.text) {
					yield {
						type: "text",
						text: streamEvent.contentBlockDelta.delta.text,
					}
					continue
				}

				// messageStop and any other event types are ignored
			}
		} catch (error: unknown) {
			console.error("Bedrock Runtime API Error:", error)
			if (error instanceof Error) {
				// Only genuine Error objects carry a stack trace worth logging
				console.error("Error stack:", error.stack)
				yield {
					type: "text",
					text: `Error: ${error.message}`,
				}
				yield {
					type: "usage",
					inputTokens: 0,
					outputTokens: 0,
				}
				throw error
			}

			const unknownError = new Error("An unknown error occurred")
			yield {
				type: "text",
				text: unknownError.message,
			}
			yield {
				type: "usage",
				inputTokens: 0,
				outputTokens: 0,
			}
			throw unknownError
		}
	}

	/**
	 * Returns the configured model and its metadata.
	 *
	 * Under `NODE_ENV === "test"` any `apiModelId` is accepted with placeholder metadata.
	 * Otherwise the ID must be a key of `bedrockModels`. Unknown or missing IDs fall back to
	 * `bedrockDefaultModelId`.
	 */
	getModel(): { id: BedrockModelId | string; info: ModelInfo } {
		const modelId = this.options.apiModelId
		if (modelId) {
			// For tests, allow any model ID
			if (process.env.NODE_ENV === "test") {
				return {
					id: modelId,
					info: {
						maxTokens: DEFAULT_MAX_TOKENS,
						contextWindow: 128_000,
						supportsPromptCache: false,
					},
				}
			}
			// For production, validate against known models
			if (modelId in bedrockModels) {
				const id = modelId as BedrockModelId
				return { id, info: bedrockModels[id] }
			}
		}
		return {
			id: bedrockDefaultModelId,
			info: bedrockModels[bedrockDefaultModelId],
		}
	}

	/**
	 * Sends a single user prompt via the non-streaming Converse API and returns the reply text.
	 *
	 * Two response shapes are handled:
	 * - the standard Converse shape, where `output.message.content` is a list of content blocks
	 *   (text blocks are concatenated);
	 * - a raw `Uint8Array` JSON body with a `content` field (kept for backward compatibility).
	 *
	 * @returns The response text, or `""` when no text could be extracted.
	 * @throws Error prefixed with "Bedrock completion error:" when the request fails with an Error.
	 */
	async completePrompt(prompt: string): Promise<string> {
		try {
			const modelConfig = this.getModel()

			const payload = {
				modelId: this.resolveInferenceModelId(modelConfig.id),
				messages: convertToBedrockConverseMessages([
					{
						role: "user",
						content: prompt,
					},
				]),
				inferenceConfig: {
					maxTokens: modelConfig.info.maxTokens || DEFAULT_MAX_TOKENS,
					temperature: 0.3,
					topP: 0.1,
				},
			}

			const command = new ConverseCommand(payload)
			const response = await this.client.send(command)

			// Backward-compatible path: output delivered as an encoded JSON body.
			if (response.output instanceof Uint8Array) {
				try {
					const outputStr = new TextDecoder().decode(response.output)
					const output = JSON.parse(outputStr)
					if (output.content) {
						return output.content
					}
				} catch (parseError) {
					console.error("Failed to parse Bedrock response:", parseError)
				}
			}

			// Standard Converse path: concatenate the text of every content block.
			const contentBlocks = response.output?.message?.content
			if (contentBlocks) {
				return contentBlocks.map((block) => block.text ?? "").join("")
			}

			return ""
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(`Bedrock completion error: ${error.message}`)
			}
			throw error
		}
	}

	/** Region the client talks to: the configured `awsRegion`, or the default when unset. */
	private getEffectiveRegion(): string {
		return this.options.awsRegion || DEFAULT_AWS_REGION
	}

	/**
	 * Maps a base model ID to the ID sent to Bedrock.
	 *
	 * With cross-region inference enabled, regions starting with "us-" / "eu-" get the
	 * "us." / "eu." inference-profile prefix. Any other region, or cross-region disabled,
	 * returns the ID unchanged. The prefix is derived from the same effective region the
	 * client uses, so an unset `awsRegion` behaves like the default region.
	 */
	private resolveInferenceModelId(baseModelId: string): string {
		if (!this.options.awsUseCrossRegionInference) {
			return baseModelId
		}
		const regionPrefix = this.getEffectiveRegion().slice(0, 3)
		switch (regionPrefix) {
			case "us-":
				return `us.${baseModelId}`
			case "eu-":
				return `eu.${baseModelId}`
			default:
				return baseModelId
		}
	}
}