From 0badfa27066a3b2ed882cd818c68a9b45a0dcce0 Mon Sep 17 00:00:00 2001 From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:58:21 -0400 Subject: [PATCH] Add support for aws credentials file or environment variables, and session token --- src/api/bedrock.ts | 5 +++-- src/providers/ClaudeDevProvider.ts | 15 +++++++++++++-- src/shared/api.ts | 1 + webview-ui/src/components/ApiOptions.tsx | 19 +++++++++++-------- .../src/context/ExtensionStateContext.tsx | 2 +- webview-ui/src/utils/validate.ts | 6 +++--- 6 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/api/bedrock.ts b/src/api/bedrock.ts index 5a9db5f..d979a94 100644 --- a/src/api/bedrock.ts +++ b/src/api/bedrock.ts @@ -13,8 +13,9 @@ export class AwsBedrockHandler implements ApiHandler { this.client = new AnthropicBedrock({ // Authenticate by either providing the keys below or use the default AWS credential providers, such as // using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables. - awsAccessKey: this.options.awsAccessKey, - awsSecretKey: this.options.awsSecretKey, + ...(this.options.awsAccessKey ? { awsAccessKey: this.options.awsAccessKey } : {}), + ...(this.options.awsSecretKey ? { awsSecretKey: this.options.awsSecretKey } : {}), + ...(this.options.awsSessionToken ? { awsSessionToken: this.options.awsSessionToken } : {}), // awsRegion changes the aws region to which the request is made. By default, we read AWS_REGION, // and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region. diff --git a/src/providers/ClaudeDevProvider.ts b/src/providers/ClaudeDevProvider.ts index 5b3c2a8..9c02256 100644 --- a/src/providers/ClaudeDevProvider.ts +++ b/src/providers/ClaudeDevProvider.ts @@ -16,7 +16,7 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts */ -type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" +type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" type GlobalStateKey = | "apiProvider" | "apiModelId" @@ -310,6 +310,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { openRouterApiKey, awsAccessKey, awsSecretKey, + awsSessionToken, awsRegion, vertexProjectId, vertexRegion, @@ -320,6 +321,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { await this.storeSecret("openRouterApiKey", openRouterApiKey) await this.storeSecret("awsAccessKey", awsAccessKey) await this.storeSecret("awsSecretKey", awsSecretKey) + await this.storeSecret("awsSessionToken", awsSessionToken) await this.updateGlobalState("awsRegion", awsRegion) await this.updateGlobalState("vertexProjectId", vertexProjectId) await this.updateGlobalState("vertexRegion", vertexRegion) @@ -609,6 +611,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { openRouterApiKey, awsAccessKey, awsSecretKey, + awsSessionToken, awsRegion, vertexProjectId, vertexRegion, @@ -623,6 +626,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { this.getSecret("openRouterApiKey") as Promise, this.getSecret("awsAccessKey") as Promise, this.getSecret("awsSecretKey") as Promise, + this.getSecret("awsSessionToken") as Promise, this.getGlobalState("awsRegion") as Promise, this.getGlobalState("vertexProjectId") as Promise, this.getGlobalState("vertexRegion") as Promise, @@ -654,6 +658,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { openRouterApiKey, awsAccessKey, awsSecretKey, + awsSessionToken, awsRegion, vertexProjectId, vertexRegion, @@ -728,7 +733,13 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { for (const key of this.context.globalState.keys()) { await this.context.globalState.update(key, undefined) } - const secretKeys: SecretKey[] = ["apiKey", "openRouterApiKey", "awsAccessKey", "awsSecretKey"] + const secretKeys: SecretKey[] = [ + "apiKey", + "openRouterApiKey", + "awsAccessKey", + "awsSecretKey", + "awsSessionToken", + ] for (const key of secretKeys) { await this.storeSecret(key, undefined) } diff --git a/src/shared/api.ts b/src/shared/api.ts index 88aa2b9..38e926f 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -6,6 +6,7 @@ export interface ApiHandlerOptions { openRouterApiKey?: string awsAccessKey?: string awsSecretKey?: string + awsSessionToken?: string awsRegion?: string vertexProjectId?: string vertexRegion?: string diff --git a/webview-ui/src/components/ApiOptions.tsx b/webview-ui/src/components/ApiOptions.tsx index 14bf84b..e0d657d 100644 --- a/webview-ui/src/components/ApiOptions.tsx +++ b/webview-ui/src/components/ApiOptions.tsx @@ -153,6 +153,14 @@ const ApiOptions: React.FC = ({ showModelOptions, apiErrorMessa placeholder="Enter Secret Key..."> AWS Secret Key + + AWS Session Token +
)} diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index 3f076b9..9d14424 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -31,7 +31,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode setState(message.state) const config = message.state?.apiConfiguration const hasKey = config - ? [config.apiKey, config.openRouterApiKey, config.awsAccessKey, config.vertexProjectId].some( + ? [config.apiKey, config.openRouterApiKey, config.awsRegion, config.vertexProjectId].some( (key) => key !== undefined ) : false diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 4812495..a64a602 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -9,8 +9,8 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s } break case "bedrock": - if (!apiConfiguration.awsAccessKey || !apiConfiguration.awsSecretKey || !apiConfiguration.awsRegion) { - return "You must provide a valid AWS access key, secret key, and region." + if (!apiConfiguration.awsRegion) { + return "You must choose a region to use with AWS Bedrock." } break case "openrouter": @@ -26,4 +26,4 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s } } return undefined -} \ No newline at end of file +}