Add AWS cross-region inference toggle

This commit is contained in:
Saoud Rizwan
2024-11-07 13:51:13 -05:00
parent 2eb11aadc7
commit ad29ff2a03
4 changed files with 36 additions and 1 deletions

View File

@@ -25,8 +25,28 @@ export class AwsBedrockHandler implements ApiHandler {
} }
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
// cross region inference requires prefixing the model id with the region
let modelId: string
if (this.options.awsUseCrossRegionInference) {
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
switch (regionPrefix) {
case "us-":
modelId = `us.${this.getModel().id}`
break
case "eu-":
modelId = `eu.${this.getModel().id}`
break
default:
// cross region inference is not supported in this region, falling back to default model
modelId = this.getModel().id
break
}
} else {
modelId = this.getModel().id
}
const stream = await this.client.messages.create({ const stream = await this.client.messages.create({
model: this.getModel().id, model: modelId,
max_tokens: this.getModel().info.maxTokens || 8192, max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0, temperature: 0,
system: systemPrompt, system: systemPrompt,

View File

@@ -40,6 +40,7 @@ type GlobalStateKey =
| "apiProvider" | "apiProvider"
| "apiModelId" | "apiModelId"
| "awsRegion" | "awsRegion"
| "awsUseCrossRegionInference"
| "vertexProjectId" | "vertexProjectId"
| "vertexRegion" | "vertexRegion"
| "lastShownAnnouncementId" | "lastShownAnnouncementId"
@@ -350,6 +351,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
awsSecretKey, awsSecretKey,
awsSessionToken, awsSessionToken,
awsRegion, awsRegion,
awsUseCrossRegionInference,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl, openAiBaseUrl,
@@ -372,6 +374,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.storeSecret("awsSecretKey", awsSecretKey) await this.storeSecret("awsSecretKey", awsSecretKey)
await this.storeSecret("awsSessionToken", awsSessionToken) await this.storeSecret("awsSessionToken", awsSessionToken)
await this.updateGlobalState("awsRegion", awsRegion) await this.updateGlobalState("awsRegion", awsRegion)
await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference)
await this.updateGlobalState("vertexProjectId", vertexProjectId) await this.updateGlobalState("vertexProjectId", vertexProjectId)
await this.updateGlobalState("vertexRegion", vertexRegion) await this.updateGlobalState("vertexRegion", vertexRegion)
await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl) await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
@@ -824,6 +827,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
awsSecretKey, awsSecretKey,
awsSessionToken, awsSessionToken,
awsRegion, awsRegion,
awsUseCrossRegionInference,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl, openAiBaseUrl,
@@ -850,6 +854,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getSecret("awsSecretKey") as Promise<string | undefined>, this.getSecret("awsSecretKey") as Promise<string | undefined>,
this.getSecret("awsSessionToken") as Promise<string | undefined>, this.getSecret("awsSessionToken") as Promise<string | undefined>,
this.getGlobalState("awsRegion") as Promise<string | undefined>, this.getGlobalState("awsRegion") as Promise<string | undefined>,
this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>,
this.getGlobalState("vertexProjectId") as Promise<string | undefined>, this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
this.getGlobalState("vertexRegion") as Promise<string | undefined>, this.getGlobalState("vertexRegion") as Promise<string | undefined>,
this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>, this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
@@ -893,6 +898,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
awsSecretKey, awsSecretKey,
awsSessionToken, awsSessionToken,
awsRegion, awsRegion,
awsUseCrossRegionInference,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl, openAiBaseUrl,

View File

@@ -19,6 +19,7 @@ export interface ApiHandlerOptions {
awsSecretKey?: string awsSecretKey?: string
awsSessionToken?: string awsSessionToken?: string
awsRegion?: string awsRegion?: string
awsUseCrossRegionInference?: boolean
vertexProjectId?: string vertexProjectId?: string
vertexRegion?: string vertexRegion?: string
openAiBaseUrl?: string openAiBaseUrl?: string

View File

@@ -307,6 +307,14 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
{/* <VSCodeOption value="us-gov-east-1">us-gov-east-1</VSCodeOption> */} {/* <VSCodeOption value="us-gov-east-1">us-gov-east-1</VSCodeOption> */}
</VSCodeDropdown> </VSCodeDropdown>
</div> </div>
<VSCodeCheckbox
checked={apiConfiguration?.awsUseCrossRegionInference || false}
onChange={(e: any) => {
const isChecked = e.target.checked === true
setApiConfiguration({ ...apiConfiguration, awsUseCrossRegionInference: isChecked })
}}>
Use cross-region inference
</VSCodeCheckbox>
<p <p
style={{ style={{
fontSize: "12px", fontSize: "12px",