From 7a61e6ab74b9ad367ac5a7f271731dfceebae0ba Mon Sep 17 00:00:00 2001 From: Lunchb0ne Date: Thu, 16 Jan 2025 18:01:49 +0000 Subject: [PATCH 1/2] Support AWS profile to configure Bedrock Authentication Added support for configurations under ~/.aws/credentials or ~/.aws/config. --- .changeset/afraid-pillows-kiss.md | 5 ++ src/api/providers/__tests__/bedrock.test.ts | 68 +++++++++++++++++ src/api/providers/bedrock.ts | 11 ++- src/core/webview/ClineProvider.ts | 12 +++ src/shared/api.ts | 2 + .../src/components/settings/ApiOptions.tsx | 74 +++++++++++++------ 6 files changed, 145 insertions(+), 27 deletions(-) create mode 100644 .changeset/afraid-pillows-kiss.md diff --git a/.changeset/afraid-pillows-kiss.md b/.changeset/afraid-pillows-kiss.md new file mode 100644 index 0000000..31e8667 --- /dev/null +++ b/.changeset/afraid-pillows-kiss.md @@ -0,0 +1,5 @@ +--- +"roo-cline": minor +--- + +Added suport for configuring Bedrock provider with AWS Profiles. Useful for users with SSO or other integrations who don't have access to long term credentials. diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index e8c1a44..357e41c 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -1,7 +1,16 @@ +// Mock AWS SDK credential providers +jest.mock("@aws-sdk/credential-providers", () => ({ + fromIni: jest.fn().mockReturnValue({ + accessKeyId: "profile-access-key", + secretAccessKey: "profile-secret-key", + }), +})) + import { AwsBedrockHandler } from "../bedrock" import { MessageContent } from "../../../shared/api" import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" import { Anthropic } from "@anthropic-ai/sdk" +import { fromIni } from "@aws-sdk/credential-providers" describe("AwsBedrockHandler", () => { let handler: AwsBedrockHandler @@ -30,6 +39,65 @@ describe("AwsBedrockHandler", () => { }) expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler) }) + + it("should initialize with AWS profile credentials", () => { + const handlerWithProfile = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsRegion: "us-east-1", + awsUseProfile: true, + awsProfile: "test-profile", + }) + expect(handlerWithProfile).toBeInstanceOf(AwsBedrockHandler) + expect(handlerWithProfile["options"].awsUseProfile).toBe(true) + expect(handlerWithProfile["options"].awsProfile).toBe("test-profile") + }) + + it("should initialize with AWS profile enabled but no profile set", () => { + const handlerWithoutProfile = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsRegion: "us-east-1", + awsUseProfile: true, + }) + expect(handlerWithoutProfile).toBeInstanceOf(AwsBedrockHandler) + expect(handlerWithoutProfile["options"].awsUseProfile).toBe(true) + expect(handlerWithoutProfile["options"].awsProfile).toBeUndefined() + }) + }) + + describe("AWS SDK client configuration", () => { + it("should configure client with profile credentials when profile mode is enabled", async () => { + // Import the fromIni function to mock it + jest.mock("@aws-sdk/credential-providers", () => ({ + fromIni: jest.fn().mockReturnValue({ + accessKeyId: "profile-access-key", + secretAccessKey: "profile-secret-key", + }), + })) + + const handlerWithProfile = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsRegion: "us-east-1", + awsUseProfile: true, + awsProfile: "test-profile", + }) + + // Mock a simple API call to verify credentials are used + const mockResponse = { + output: new TextEncoder().encode(JSON.stringify({ content: "test" })), + } + const mockSend = jest.fn().mockResolvedValue(mockResponse) + handlerWithProfile["client"] = { + send: mockSend, + } as unknown as BedrockRuntimeClient + + await handlerWithProfile.completePrompt("test") + + // Verify the client was configured with profile credentials + expect(mockSend).toHaveBeenCalled() + expect(fromIni).toHaveBeenCalledWith({ + profile: "test-profile", + }) + }) }) describe("createMessage", () => { diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 87591b7..0e90c2b 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -4,6 +4,7 @@ import { ConverseCommand, BedrockRuntimeClientConfig, } from "@aws-sdk/client-bedrock-runtime" +import { fromIni } from "@aws-sdk/credential-providers" import { Anthropic } from "@anthropic-ai/sdk" import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" @@ -50,13 +51,17 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler { constructor(options: ApiHandlerOptions) { this.options = options - // Only include credentials if they actually exist const clientConfig: BedrockRuntimeClientConfig = { region: this.options.awsRegion || "us-east-1", } - if (this.options.awsAccessKey && this.options.awsSecretKey) { - // Create credentials object with all properties at once + if (this.options.awsUseProfile && this.options.awsProfile) { + // Use profile-based credentials if enabled and profile is set + clientConfig.credentials = fromIni({ + profile: this.options.awsProfile, + }) + } else if (this.options.awsAccessKey && this.options.awsSecretKey) { + // Use direct credentials if provided clientConfig.credentials = { accessKeyId: this.options.awsAccessKey, secretAccessKey: this.options.awsSecretKey, diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 0a775e2..aa4cc61 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -56,6 +56,8 @@ type GlobalStateKey = | "glamaModelInfo" | "awsRegion" | "awsUseCrossRegionInference" + | "awsProfile" + | "awsUseProfile" | "vertexProjectId" | "vertexRegion" | "lastShownAnnouncementId" @@ -1147,6 +1149,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { awsSessionToken, awsRegion, awsUseCrossRegionInference, + awsProfile, + awsUseProfile, vertexProjectId, vertexRegion, openAiBaseUrl, @@ -1180,6 +1184,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.storeSecret("awsSessionToken", awsSessionToken) await this.updateGlobalState("awsRegion", awsRegion) await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference) + await this.updateGlobalState("awsProfile", awsProfile) + await this.updateGlobalState("awsUseProfile", awsUseProfile) await this.updateGlobalState("vertexProjectId", vertexProjectId) await this.updateGlobalState("vertexRegion", vertexRegion) await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl) @@ -1795,6 +1801,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { awsSessionToken, awsRegion, awsUseCrossRegionInference, + awsProfile, + awsUseProfile, vertexProjectId, vertexRegion, openAiBaseUrl, @@ -1857,6 +1865,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getSecret("awsSessionToken") as Promise, this.getGlobalState("awsRegion") as Promise, this.getGlobalState("awsUseCrossRegionInference") as Promise, + this.getGlobalState("awsProfile") as Promise, + this.getGlobalState("awsUseProfile") as Promise, this.getGlobalState("vertexProjectId") as Promise, this.getGlobalState("vertexRegion") as Promise, this.getGlobalState("openAiBaseUrl") as Promise, @@ -1936,6 +1946,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { awsSessionToken, awsRegion, awsUseCrossRegionInference, + awsProfile, + awsUseProfile, vertexProjectId, vertexRegion, openAiBaseUrl, diff --git a/src/shared/api.ts b/src/shared/api.ts index 4fd25ba..a39e010 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -33,6 +33,8 @@ export interface ApiHandlerOptions { awsUseCrossRegionInference?: boolean awsUsePromptCache?: boolean awspromptCacheId?: string + awsProfile?: string + awsUseProfile?: boolean vertexProjectId?: string vertexRegion?: string openAiBaseUrl?: string diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 8e6fe42..b7ab5a3 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -342,30 +342,56 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = {selectedProvider === "bedrock" && (
- - AWS Access Key - - - AWS Secret Key - - - AWS Session Token - + { + const value = (e.target as HTMLInputElement)?.value + const useProfile = value === "profile" + handleInputChange("awsUseProfile")({ + target: { value: useProfile }, + }) + }}> + AWS Credentials + AWS Profile + + {/* AWS Profile Config Block */} + {apiConfiguration?.awsUseProfile ? ( + + AWS Profile Name + + ) : ( + <> + {/* AWS Credentials Config Block */} + + AWS Access Key + + + AWS Secret Key + + + AWS Session Token + + + )}