mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-21 21:01:06 -05:00
Add non-streaming completePrompt to all providers
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
|
||||
@@ -38,7 +38,7 @@ export interface StreamEvent {
|
||||
};
|
||||
}
|
||||
|
||||
export class AwsBedrockHandler implements ApiHandler {
|
||||
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: BedrockRuntimeClient
|
||||
|
||||
@@ -199,7 +199,7 @@ export class AwsBedrockHandler implements ApiHandler {
|
||||
if (modelId) {
|
||||
// For tests, allow any model ID
|
||||
if (process.env.NODE_ENV === 'test') {
|
||||
return {
|
||||
return {
|
||||
id: modelId,
|
||||
info: {
|
||||
maxTokens: 5000,
|
||||
@@ -214,9 +214,68 @@ export class AwsBedrockHandler implements ApiHandler {
|
||||
return { id, info: bedrockModels[id] }
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId]
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId]
|
||||
}
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const modelConfig = this.getModel()
|
||||
|
||||
// Handle cross-region inference
|
||||
let modelId: string
|
||||
if (this.options.awsUseCrossRegionInference) {
|
||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||
switch (regionPrefix) {
|
||||
case "us-":
|
||||
modelId = `us.${modelConfig.id}`
|
||||
break
|
||||
case "eu-":
|
||||
modelId = `eu.${modelConfig.id}`
|
||||
break
|
||||
default:
|
||||
modelId = modelConfig.id
|
||||
break
|
||||
}
|
||||
} else {
|
||||
modelId = modelConfig.id
|
||||
}
|
||||
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([{
|
||||
role: "user",
|
||||
content: prompt
|
||||
}]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
}
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
const response = await this.client.send(command)
|
||||
|
||||
if (response.output && response.output instanceof Uint8Array) {
|
||||
try {
|
||||
const outputStr = new TextDecoder().decode(response.output)
|
||||
const output = JSON.parse(outputStr)
|
||||
if (output.content) {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse Bedrock response:', parseError)
|
||||
}
|
||||
}
|
||||
return ''
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user