mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-21 21:01:06 -05:00
Prettier backfill
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseStreamCommand,
|
||||
ConverseCommand,
|
||||
BedrockRuntimeClientConfig,
|
||||
} from "@aws-sdk/client-bedrock-runtime"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||
@@ -7,275 +12,276 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../
|
||||
|
||||
// Define types for stream events based on AWS SDK
|
||||
export interface StreamEvent {
|
||||
messageStart?: {
|
||||
role?: string;
|
||||
};
|
||||
messageStop?: {
|
||||
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence";
|
||||
additionalModelResponseFields?: Record<string, unknown>;
|
||||
};
|
||||
contentBlockStart?: {
|
||||
start?: {
|
||||
text?: string;
|
||||
};
|
||||
contentBlockIndex?: number;
|
||||
};
|
||||
contentBlockDelta?: {
|
||||
delta?: {
|
||||
text?: string;
|
||||
};
|
||||
contentBlockIndex?: number;
|
||||
};
|
||||
metadata?: {
|
||||
usage?: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalTokens?: number; // Made optional since we don't use it
|
||||
};
|
||||
metrics?: {
|
||||
latencyMs: number;
|
||||
};
|
||||
};
|
||||
messageStart?: {
|
||||
role?: string
|
||||
}
|
||||
messageStop?: {
|
||||
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
|
||||
additionalModelResponseFields?: Record<string, unknown>
|
||||
}
|
||||
contentBlockStart?: {
|
||||
start?: {
|
||||
text?: string
|
||||
}
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
contentBlockDelta?: {
|
||||
delta?: {
|
||||
text?: string
|
||||
}
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
metadata?: {
|
||||
usage?: {
|
||||
inputTokens: number
|
||||
outputTokens: number
|
||||
totalTokens?: number // Made optional since we don't use it
|
||||
}
|
||||
metrics?: {
|
||||
latencyMs: number
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: BedrockRuntimeClient
|
||||
private options: ApiHandlerOptions
|
||||
private client: BedrockRuntimeClient
|
||||
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options
|
||||
|
||||
// Only include credentials if they actually exist
|
||||
const clientConfig: BedrockRuntimeClientConfig = {
|
||||
region: this.options.awsRegion || "us-east-1"
|
||||
}
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options
|
||||
|
||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||
// Create credentials object with all properties at once
|
||||
clientConfig.credentials = {
|
||||
accessKeyId: this.options.awsAccessKey,
|
||||
secretAccessKey: this.options.awsSecretKey,
|
||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {})
|
||||
}
|
||||
}
|
||||
// Only include credentials if they actually exist
|
||||
const clientConfig: BedrockRuntimeClientConfig = {
|
||||
region: this.options.awsRegion || "us-east-1",
|
||||
}
|
||||
|
||||
this.client = new BedrockRuntimeClient(clientConfig)
|
||||
}
|
||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||
// Create credentials object with all properties at once
|
||||
clientConfig.credentials = {
|
||||
accessKeyId: this.options.awsAccessKey,
|
||||
secretAccessKey: this.options.awsSecretKey,
|
||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
|
||||
}
|
||||
}
|
||||
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
this.client = new BedrockRuntimeClient(clientConfig)
|
||||
}
|
||||
|
||||
// Convert messages to Bedrock format
|
||||
const formattedMessages = convertToBedrockConverseMessages(messages)
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Construct the payload
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: formattedMessages,
|
||||
system: [{ text: systemPrompt }],
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
...(this.options.awsUsePromptCache ? {
|
||||
promptCache: {
|
||||
promptCacheId: this.options.awspromptCacheId || ""
|
||||
}
|
||||
} : {})
|
||||
}
|
||||
}
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
|
||||
try {
|
||||
const command = new ConverseStreamCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
// Convert messages to Bedrock format
|
||||
const formattedMessages = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!response.stream) {
|
||||
throw new Error('No stream available in the response')
|
||||
}
|
||||
// Construct the payload
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: formattedMessages,
|
||||
system: [{ text: systemPrompt }],
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
...(this.options.awsUsePromptCache
|
||||
? {
|
||||
promptCache: {
|
||||
promptCacheId: this.options.awspromptCacheId || "",
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
},
|
||||
}
|
||||
|
||||
for await (const chunk of response.stream) {
|
||||
// Parse the chunk as JSON if it's a string (for tests)
|
||||
let streamEvent: StreamEvent
|
||||
try {
|
||||
streamEvent = typeof chunk === 'string' ?
|
||||
JSON.parse(chunk) :
|
||||
chunk as unknown as StreamEvent
|
||||
} catch (e) {
|
||||
console.error('Failed to parse stream event:', e)
|
||||
continue
|
||||
}
|
||||
try {
|
||||
const command = new ConverseStreamCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
|
||||
// Handle metadata events first
|
||||
if (streamEvent.metadata?.usage) {
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if (!response.stream) {
|
||||
throw new Error("No stream available in the response")
|
||||
}
|
||||
|
||||
// Handle message start
|
||||
if (streamEvent.messageStart) {
|
||||
continue
|
||||
}
|
||||
for await (const chunk of response.stream) {
|
||||
// Parse the chunk as JSON if it's a string (for tests)
|
||||
let streamEvent: StreamEvent
|
||||
try {
|
||||
streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
|
||||
} catch (e) {
|
||||
console.error("Failed to parse stream event:", e)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle content blocks
|
||||
if (streamEvent.contentBlockStart?.start?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockStart.start.text
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Handle metadata events first
|
||||
if (streamEvent.metadata?.usage) {
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle content deltas
|
||||
if (streamEvent.contentBlockDelta?.delta?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockDelta.delta.text
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Handle message start
|
||||
if (streamEvent.messageStart) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle message stop
|
||||
if (streamEvent.messageStop) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Handle content blocks
|
||||
if (streamEvent.contentBlockStart?.start?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockStart.start.text,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
} catch (error: unknown) {
|
||||
console.error('Bedrock Runtime API Error:', error)
|
||||
// Only access stack if error is an Error object
|
||||
if (error instanceof Error) {
|
||||
console.error('Error stack:', error.stack)
|
||||
yield {
|
||||
type: "text",
|
||||
text: `Error: ${error.message}`
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
}
|
||||
throw error
|
||||
} else {
|
||||
const unknownError = new Error("An unknown error occurred")
|
||||
yield {
|
||||
type: "text",
|
||||
text: unknownError.message
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
}
|
||||
throw unknownError
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle content deltas
|
||||
if (streamEvent.contentBlockDelta?.delta?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockDelta.delta.text,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
||||
const modelId = this.options.apiModelId
|
||||
if (modelId) {
|
||||
// For tests, allow any model ID
|
||||
if (process.env.NODE_ENV === 'test') {
|
||||
return {
|
||||
id: modelId,
|
||||
info: {
|
||||
maxTokens: 5000,
|
||||
contextWindow: 128_000,
|
||||
supportsPromptCache: false
|
||||
}
|
||||
}
|
||||
}
|
||||
// For production, validate against known models
|
||||
if (modelId in bedrockModels) {
|
||||
const id = modelId as BedrockModelId
|
||||
return { id, info: bedrockModels[id] }
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId]
|
||||
}
|
||||
}
|
||||
// Handle message stop
|
||||
if (streamEvent.messageStop) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
console.error("Bedrock Runtime API Error:", error)
|
||||
// Only access stack if error is an Error object
|
||||
if (error instanceof Error) {
|
||||
console.error("Error stack:", error.stack)
|
||||
yield {
|
||||
type: "text",
|
||||
text: `Error: ${error.message}`,
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
}
|
||||
throw error
|
||||
} else {
|
||||
const unknownError = new Error("An unknown error occurred")
|
||||
yield {
|
||||
type: "text",
|
||||
text: unknownError.message,
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
}
|
||||
throw unknownError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
||||
const modelId = this.options.apiModelId
|
||||
if (modelId) {
|
||||
// For tests, allow any model ID
|
||||
if (process.env.NODE_ENV === "test") {
|
||||
return {
|
||||
id: modelId,
|
||||
info: {
|
||||
maxTokens: 5000,
|
||||
contextWindow: 128_000,
|
||||
supportsPromptCache: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
// For production, validate against known models
|
||||
if (modelId in bedrockModels) {
|
||||
const id = modelId as BedrockModelId
|
||||
return { id, info: bedrockModels[id] }
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId],
|
||||
}
|
||||
}
|
||||
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([{
|
||||
role: "user",
|
||||
content: prompt
|
||||
}]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
}
|
||||
}
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
|
||||
if (response.output && response.output instanceof Uint8Array) {
|
||||
try {
|
||||
const outputStr = new TextDecoder().decode(response.output)
|
||||
const output = JSON.parse(outputStr)
|
||||
if (output.content) {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse Bedrock response:', parseError)
|
||||
}
|
||||
}
|
||||
return ''
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([
|
||||
{
|
||||
role: "user",
|
||||
content: prompt,
|
||||
},
|
||||
]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
|
||||
if (response.output && response.output instanceof Uint8Array) {
|
||||
try {
|
||||
const outputStr = new TextDecoder().decode(response.output)
|
||||
const output = JSON.parse(outputStr)
|
||||
if (output.content) {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error("Failed to parse Bedrock response:", parseError)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user