diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index ad6e8df..58f75ad 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -25,8 +25,28 @@ export class AwsBedrockHandler implements ApiHandler { } async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + // cross region inference requires prefixing the model id with the region + let modelId: string + if (this.options.awsUseCrossRegionInference) { + let regionPrefix = (this.options.awsRegion || "").slice(0, 3) + switch (regionPrefix) { + case "us-": + modelId = `us.${this.getModel().id}` + break + case "eu-": + modelId = `eu.${this.getModel().id}` + break + default: + // cross region inference is not supported in this region, falling back to default model + modelId = this.getModel().id + break + } + } else { + modelId = this.getModel().id + } + const stream = await this.client.messages.create({ - model: this.getModel().id, + model: modelId, max_tokens: this.getModel().info.maxTokens || 8192, temperature: 0, system: systemPrompt, diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 6f19138..6d24ce8 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -40,6 +40,7 @@ type GlobalStateKey = | "apiProvider" | "apiModelId" | "awsRegion" + | "awsUseCrossRegionInference" | "vertexProjectId" | "vertexRegion" | "lastShownAnnouncementId" @@ -350,6 +351,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { awsSecretKey, awsSessionToken, awsRegion, + awsUseCrossRegionInference, vertexProjectId, vertexRegion, openAiBaseUrl, @@ -372,6 +374,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.storeSecret("awsSecretKey", awsSecretKey) await this.storeSecret("awsSessionToken", awsSessionToken) await this.updateGlobalState("awsRegion", awsRegion) + await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference) await this.updateGlobalState("vertexProjectId", vertexProjectId) await this.updateGlobalState("vertexRegion", vertexRegion) await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl) @@ -824,6 +827,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { awsSecretKey, awsSessionToken, awsRegion, + awsUseCrossRegionInference, vertexProjectId, vertexRegion, openAiBaseUrl, @@ -850,6 +854,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getSecret("awsSecretKey") as Promise, this.getSecret("awsSessionToken") as Promise, this.getGlobalState("awsRegion") as Promise, + this.getGlobalState("awsUseCrossRegionInference") as Promise, this.getGlobalState("vertexProjectId") as Promise, this.getGlobalState("vertexRegion") as Promise, this.getGlobalState("openAiBaseUrl") as Promise, @@ -893,6 +898,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { awsSecretKey, awsSessionToken, awsRegion, + awsUseCrossRegionInference, vertexProjectId, vertexRegion, openAiBaseUrl, diff --git a/src/shared/api.ts b/src/shared/api.ts index a466972..b2c5525 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -19,6 +19,7 @@ export interface ApiHandlerOptions { awsSecretKey?: string awsSessionToken?: string awsRegion?: string + awsUseCrossRegionInference?: boolean vertexProjectId?: string vertexRegion?: string openAiBaseUrl?: string diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 7e0fbef..f5e44ca 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -307,6 +307,14 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }: {/* us-gov-east-1 */} + { + const isChecked = e.target.checked === true + setApiConfiguration({ ...apiConfiguration, awsUseCrossRegionInference: isChecked }) + }}> + Use cross-region inference +