Files
Roo-Code/src/api/providers/bedrock.ts
Lunchb0ne 7a61e6ab74 Support AWS profile to configure Bedrock Authentication
Added support for configurations under ~/.aws/credentials or ~/.aws/config.
2025-01-18 12:41:46 +00:00

293 lines
7.2 KiB
TypeScript

import {
BedrockRuntimeClient,
ConverseStreamCommand,
ConverseCommand,
BedrockRuntimeClientConfig,
} from "@aws-sdk/client-bedrock-runtime"
import { fromIni } from "@aws-sdk/credential-providers"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
/**
 * Loose, hand-written shape of a single ConverseStream event as consumed by
 * `AwsBedrockHandler.createMessage` (which casts each SDK stream chunk to this type).
 *
 * Presumably each event populates only one of the top-level keys below — TODO confirm
 * against the AWS SDK `ConverseStreamOutput` union.
 */
export interface StreamEvent {
	// Start of the assistant message; createMessage skips these events.
	messageStart?: {
		role?: string
	}
	// End of the assistant message; createMessage emits nothing for it.
	messageStop?: {
		stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
		additionalModelResponseFields?: Record<string, unknown>
	}
	// Opening of a content block; any `start.text` present is emitted as a text chunk.
	contentBlockStart?: {
		start?: {
			text?: string
		}
		contentBlockIndex?: number
	}
	// Incremental content; `delta.text` is emitted as a text chunk.
	contentBlockDelta?: {
		delta?: {
			text?: string
		}
		contentBlockIndex?: number
	}
	// Usage/latency info; only `usage.inputTokens` / `usage.outputTokens` are read (as a usage chunk).
	metadata?: {
		usage?: {
			inputTokens: number
			outputTokens: number
			totalTokens?: number // Made optional since we don't use it
		}
		// Not read anywhere in this file.
		metrics?: {
			latencyMs: number
		}
	}
}
/**
 * API handler that talks to Amazon Bedrock through the Converse / ConverseStream APIs.
 *
 * Credential selection (constructor):
 *  1. `awsUseProfile` + `awsProfile` set → named profile from ~/.aws/credentials or ~/.aws/config (`fromIni`).
 *  2. `awsAccessKey` + `awsSecretKey` set → static credentials (plus `awsSessionToken` when present).
 *  3. Otherwise no credentials are passed explicitly and the AWS SDK falls back to its default provider chain.
 */
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
	private options: ApiHandlerOptions
	private client: BedrockRuntimeClient

	constructor(options: ApiHandlerOptions) {
		this.options = options
		const clientConfig: BedrockRuntimeClientConfig = {
			region: this.getRegion(),
		}
		if (this.options.awsUseProfile && this.options.awsProfile) {
			// Profile-based credentials take precedence when explicitly enabled and named.
			clientConfig.credentials = fromIni({
				profile: this.options.awsProfile,
			})
		} else if (this.options.awsAccessKey && this.options.awsSecretKey) {
			// Direct (static) credentials; the session token is only attached when provided.
			clientConfig.credentials = {
				accessKeyId: this.options.awsAccessKey,
				secretAccessKey: this.options.awsSecretKey,
				...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
			}
		}
		this.client = new BedrockRuntimeClient(clientConfig)
	}

	/**
	 * Streams a response for the given conversation via `ConverseStreamCommand`.
	 *
	 * Yields `text` chunks for content-block starts/deltas and `usage` chunks for metadata events.
	 * On failure it logs the error, yields an error text chunk plus a zero usage chunk, then rethrows
	 * (non-`Error` throwables are replaced by `Error("An unknown error occurred")`).
	 *
	 * @param systemPrompt - Sent as the single Converse `system` text block.
	 * @param messages - Anthropic-format messages, converted to Bedrock Converse format.
	 */
	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const modelConfig = this.getModel()
		const payload = {
			modelId: this.getModelId(modelConfig.id),
			messages: convertToBedrockConverseMessages(messages),
			system: [{ text: systemPrompt }],
			inferenceConfig: {
				maxTokens: modelConfig.info.maxTokens || 5000,
				temperature: 0.3,
				topP: 0.1,
				// NOTE(review): `promptCache` is not part of the documented Converse
				// InferenceConfiguration shape; confirm Bedrock / the SDK actually honours it.
				...(this.options.awsUsePromptCache
					? {
							promptCache: {
								promptCacheId: this.options.awspromptCacheId || "",
							},
						}
					: {}),
			},
		}

		try {
			const command = new ConverseStreamCommand(payload)
			const response = await this.client.send(command)
			if (!response.stream) {
				throw new Error("No stream available in the response")
			}

			for await (const chunk of response.stream) {
				// Real SDK events are objects; tests may feed JSON strings instead.
				let streamEvent: StreamEvent
				try {
					streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
				} catch (e) {
					console.error("Failed to parse stream event:", e)
					continue
				}

				// Usage metadata is checked first; an event carrying it yields nothing else.
				if (streamEvent.metadata?.usage) {
					yield {
						type: "usage",
						inputTokens: streamEvent.metadata.usage.inputTokens || 0,
						outputTokens: streamEvent.metadata.usage.outputTokens || 0,
					}
					continue
				}

				// messageStart carries no text; skip it before inspecting content fields.
				if (streamEvent.messageStart) {
					continue
				}

				if (streamEvent.contentBlockStart?.start?.text) {
					yield {
						type: "text",
						text: streamEvent.contentBlockStart.start.text,
					}
					continue
				}

				if (streamEvent.contentBlockDelta?.delta?.text) {
					yield {
						type: "text",
						text: streamEvent.contentBlockDelta.delta.text,
					}
				}
				// messageStop and any other event types carry nothing this handler emits.
			}
		} catch (error: unknown) {
			console.error("Bedrock Runtime API Error:", error)

			let failure: Error
			let errorText: string
			if (error instanceof Error) {
				// Only access stack if error is an Error object
				console.error("Error stack:", error.stack)
				failure = error
				errorText = `Error: ${error.message}`
			} else {
				failure = new Error("An unknown error occurred")
				errorText = failure.message
			}

			yield {
				type: "text",
				text: errorText,
			}
			yield {
				type: "usage",
				inputTokens: 0,
				outputTokens: 0,
			}
			throw failure
		}
	}

	/**
	 * Returns the configured model and its metadata.
	 *
	 * Under `NODE_ENV === "test"` any configured ID is accepted with placeholder metadata. Otherwise the ID
	 * must be an own key of `bedrockModels`; unknown or missing IDs fall back to `bedrockDefaultModelId`.
	 */
	getModel(): { id: BedrockModelId | string; info: ModelInfo } {
		const modelId = this.options.apiModelId
		if (modelId) {
			// For tests, allow any model ID
			if (process.env.NODE_ENV === "test") {
				return {
					id: modelId,
					info: {
						maxTokens: 5000,
						contextWindow: 128_000,
						supportsPromptCache: false,
					},
				}
			}
			// Own-property check: `in` would also accept inherited keys such as "toString".
			if (Object.prototype.hasOwnProperty.call(bedrockModels, modelId)) {
				const id = modelId as BedrockModelId
				return { id, info: bedrockModels[id] }
			}
		}
		return {
			id: bedrockDefaultModelId,
			info: bedrockModels[bedrockDefaultModelId],
		}
	}

	/**
	 * Runs a single, non-streaming completion for `prompt` via `ConverseCommand`.
	 *
	 * @returns The concatenated text of the assistant message's text blocks, or `""` if none.
	 * @throws Error prefixed with "Bedrock completion error:" when the request fails.
	 */
	async completePrompt(prompt: string): Promise<string> {
		try {
			const modelConfig = this.getModel()
			const payload = {
				modelId: this.getModelId(modelConfig.id),
				messages: convertToBedrockConverseMessages([
					{
						role: "user",
						content: prompt,
					},
				]),
				inferenceConfig: {
					maxTokens: modelConfig.info.maxTokens || 5000,
					temperature: 0.3,
					topP: 0.1,
				},
			}

			const command = new ConverseCommand(payload)
			const response = await this.client.send(command)
			const output = response.output
			if (!output) {
				return ""
			}

			// Encoded-JSON payload `{ content: string }`; this is the shape existing unit tests mock.
			if (output instanceof Uint8Array) {
				try {
					const parsed = JSON.parse(new TextDecoder().decode(output))
					return parsed.content ? parsed.content : ""
				} catch (parseError) {
					console.error("Failed to parse Bedrock response:", parseError)
					return ""
				}
			}

			// Real Converse responses carry the assistant message as a list of content blocks;
			// concatenate the text of every text block (non-text blocks contribute nothing).
			let text = ""
			for (const block of output.message?.content ?? []) {
				text += block.text ?? ""
			}
			return text
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(`Bedrock completion error: ${error.message}`)
			}
			throw error
		}
	}

	/** Effective AWS region, shared by the client config and cross-region prefix selection. */
	private getRegion(): string {
		return this.options.awsRegion || "us-east-1"
	}

	/**
	 * Resolves the model ID sent to Bedrock. When cross-region inference is enabled, a "us." or "eu."
	 * inference-profile prefix is added based on the region; other regions use the base ID unchanged.
	 */
	private getModelId(baseId: string): string {
		if (!this.options.awsUseCrossRegionInference) {
			return baseId
		}
		const regionPrefix = this.getRegion().slice(0, 3)
		switch (regionPrefix) {
			case "us-":
				return `us.${baseId}`
			case "eu-":
				return `eu.${baseId}`
			default:
				return baseId
		}
	}
}