Merge pull request #392 from Lunchb0ne/aws-profile-support

Add support for using AWS Profile
This commit is contained in:
Matt Rubens
2025-01-24 09:14:40 -05:00
committed by GitHub
6 changed files with 137 additions and 27 deletions

View File

@@ -0,0 +1,5 @@
---
"roo-cline": minor
---
Added suport for configuring Bedrock provider with AWS Profiles. Useful for users with SSO or other integrations who don't have access to long term credentials.

View File

@@ -1,7 +1,16 @@
// Mock AWS SDK credential providers
jest.mock("@aws-sdk/credential-providers", () => ({
fromIni: jest.fn().mockReturnValue({
accessKeyId: "profile-access-key",
secretAccessKey: "profile-secret-key",
}),
}))
import { AwsBedrockHandler } from "../bedrock" import { AwsBedrockHandler } from "../bedrock"
import { MessageContent } from "../../../shared/api" import { MessageContent } from "../../../shared/api"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
import { fromIni } from "@aws-sdk/credential-providers"
describe("AwsBedrockHandler", () => { describe("AwsBedrockHandler", () => {
let handler: AwsBedrockHandler let handler: AwsBedrockHandler
@@ -30,6 +39,57 @@ describe("AwsBedrockHandler", () => {
}) })
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler) expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
}) })
it("should initialize with AWS profile credentials", () => {
const handlerWithProfile = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsRegion: "us-east-1",
awsUseProfile: true,
awsProfile: "test-profile",
})
expect(handlerWithProfile).toBeInstanceOf(AwsBedrockHandler)
expect(handlerWithProfile["options"].awsUseProfile).toBe(true)
expect(handlerWithProfile["options"].awsProfile).toBe("test-profile")
})
it("should initialize with AWS profile enabled but no profile set", () => {
const handlerWithoutProfile = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsRegion: "us-east-1",
awsUseProfile: true,
})
expect(handlerWithoutProfile).toBeInstanceOf(AwsBedrockHandler)
expect(handlerWithoutProfile["options"].awsUseProfile).toBe(true)
expect(handlerWithoutProfile["options"].awsProfile).toBeUndefined()
})
})
describe("AWS SDK client configuration", () => {
it("should configure client with profile credentials when profile mode is enabled", async () => {
const handlerWithProfile = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsRegion: "us-east-1",
awsUseProfile: true,
awsProfile: "test-profile",
})
// Mock a simple API call to verify credentials are used
const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({ content: "test" })),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse)
handlerWithProfile["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
await handlerWithProfile.completePrompt("test")
// Verify the client was configured with profile credentials
expect(mockSend).toHaveBeenCalled()
expect(fromIni).toHaveBeenCalledWith({
profile: "test-profile",
})
})
}) })
describe("createMessage", () => { describe("createMessage", () => {

View File

@@ -4,6 +4,7 @@ import {
ConverseCommand, ConverseCommand,
BedrockRuntimeClientConfig, BedrockRuntimeClientConfig,
} from "@aws-sdk/client-bedrock-runtime" } from "@aws-sdk/client-bedrock-runtime"
import { fromIni } from "@aws-sdk/credential-providers"
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
@@ -50,13 +51,17 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
constructor(options: ApiHandlerOptions) { constructor(options: ApiHandlerOptions) {
this.options = options this.options = options
// Only include credentials if they actually exist
const clientConfig: BedrockRuntimeClientConfig = { const clientConfig: BedrockRuntimeClientConfig = {
region: this.options.awsRegion || "us-east-1", region: this.options.awsRegion || "us-east-1",
} }
if (this.options.awsAccessKey && this.options.awsSecretKey) { if (this.options.awsUseProfile && this.options.awsProfile) {
// Create credentials object with all properties at once // Use profile-based credentials if enabled and profile is set
clientConfig.credentials = fromIni({
profile: this.options.awsProfile,
})
} else if (this.options.awsAccessKey && this.options.awsSecretKey) {
// Use direct credentials if provided
clientConfig.credentials = { clientConfig.credentials = {
accessKeyId: this.options.awsAccessKey, accessKeyId: this.options.awsAccessKey,
secretAccessKey: this.options.awsSecretKey, secretAccessKey: this.options.awsSecretKey,

View File

@@ -69,6 +69,8 @@ type GlobalStateKey =
| "glamaModelInfo" | "glamaModelInfo"
| "awsRegion" | "awsRegion"
| "awsUseCrossRegionInference" | "awsUseCrossRegionInference"
| "awsProfile"
| "awsUseProfile"
| "vertexProjectId" | "vertexProjectId"
| "vertexRegion" | "vertexRegion"
| "lastShownAnnouncementId" | "lastShownAnnouncementId"
@@ -1264,6 +1266,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
awsSessionToken, awsSessionToken,
awsRegion, awsRegion,
awsUseCrossRegionInference, awsUseCrossRegionInference,
awsProfile,
awsUseProfile,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl, openAiBaseUrl,
@@ -1299,6 +1303,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.storeSecret("awsSessionToken", awsSessionToken) await this.storeSecret("awsSessionToken", awsSessionToken)
await this.updateGlobalState("awsRegion", awsRegion) await this.updateGlobalState("awsRegion", awsRegion)
await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference) await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference)
await this.updateGlobalState("awsProfile", awsProfile)
await this.updateGlobalState("awsUseProfile", awsUseProfile)
await this.updateGlobalState("vertexProjectId", vertexProjectId) await this.updateGlobalState("vertexProjectId", vertexProjectId)
await this.updateGlobalState("vertexRegion", vertexRegion) await this.updateGlobalState("vertexRegion", vertexRegion)
await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl) await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
@@ -1919,6 +1925,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
awsSessionToken, awsSessionToken,
awsRegion, awsRegion,
awsUseCrossRegionInference, awsUseCrossRegionInference,
awsProfile,
awsUseProfile,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl, openAiBaseUrl,
@@ -1985,6 +1993,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getSecret("awsSessionToken") as Promise<string | undefined>, this.getSecret("awsSessionToken") as Promise<string | undefined>,
this.getGlobalState("awsRegion") as Promise<string | undefined>, this.getGlobalState("awsRegion") as Promise<string | undefined>,
this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>, this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>,
this.getGlobalState("awsProfile") as Promise<string | undefined>,
this.getGlobalState("awsUseProfile") as Promise<boolean | undefined>,
this.getGlobalState("vertexProjectId") as Promise<string | undefined>, this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
this.getGlobalState("vertexRegion") as Promise<string | undefined>, this.getGlobalState("vertexRegion") as Promise<string | undefined>,
this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>, this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
@@ -2068,6 +2078,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
awsSessionToken, awsSessionToken,
awsRegion, awsRegion,
awsUseCrossRegionInference, awsUseCrossRegionInference,
awsProfile,
awsUseProfile,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl, openAiBaseUrl,

View File

@@ -33,6 +33,8 @@ export interface ApiHandlerOptions {
awsUseCrossRegionInference?: boolean awsUseCrossRegionInference?: boolean
awsUsePromptCache?: boolean awsUsePromptCache?: boolean
awspromptCacheId?: string awspromptCacheId?: string
awsProfile?: string
awsUseProfile?: boolean
vertexProjectId?: string vertexProjectId?: string
vertexRegion?: string vertexRegion?: string
openAiBaseUrl?: string openAiBaseUrl?: string

View File

@@ -340,30 +340,56 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
{selectedProvider === "bedrock" && ( {selectedProvider === "bedrock" && (
<div style={{ display: "flex", flexDirection: "column", gap: 5 }}> <div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
<VSCodeTextField <VSCodeRadioGroup
value={apiConfiguration?.awsAccessKey || ""} value={apiConfiguration?.awsUseProfile ? "profile" : "credentials"}
style={{ width: "100%" }} onChange={(e) => {
type="password" const value = (e.target as HTMLInputElement)?.value
onInput={handleInputChange("awsAccessKey")} const useProfile = value === "profile"
placeholder="Enter Access Key..."> handleInputChange("awsUseProfile")({
<span style={{ fontWeight: 500 }}>AWS Access Key</span> target: { value: useProfile },
</VSCodeTextField> })
<VSCodeTextField }}>
value={apiConfiguration?.awsSecretKey || ""} <VSCodeRadio value="credentials">AWS Credentials</VSCodeRadio>
style={{ width: "100%" }} <VSCodeRadio value="profile">AWS Profile</VSCodeRadio>
type="password" </VSCodeRadioGroup>
onInput={handleInputChange("awsSecretKey")} {/* AWS Profile Config Block */}
placeholder="Enter Secret Key..."> {apiConfiguration?.awsUseProfile ? (
<span style={{ fontWeight: 500 }}>AWS Secret Key</span> <VSCodeTextField
</VSCodeTextField> value={apiConfiguration?.awsProfile || ""}
<VSCodeTextField style={{ width: "100%" }}
value={apiConfiguration?.awsSessionToken || ""} onInput={handleInputChange("awsProfile")}
style={{ width: "100%" }} placeholder="Enter profile name">
type="password" <span style={{ fontWeight: 500 }}>AWS Profile Name</span>
onInput={handleInputChange("awsSessionToken")} </VSCodeTextField>
placeholder="Enter Session Token..."> ) : (
<span style={{ fontWeight: 500 }}>AWS Session Token</span> <>
</VSCodeTextField> {/* AWS Credentials Config Block */}
<VSCodeTextField
value={apiConfiguration?.awsAccessKey || ""}
style={{ width: "100%" }}
type="password"
onInput={handleInputChange("awsAccessKey")}
placeholder="Enter Access Key...">
<span style={{ fontWeight: 500 }}>AWS Access Key</span>
</VSCodeTextField>
<VSCodeTextField
value={apiConfiguration?.awsSecretKey || ""}
style={{ width: "100%" }}
type="password"
onInput={handleInputChange("awsSecretKey")}
placeholder="Enter Secret Key...">
<span style={{ fontWeight: 500 }}>AWS Secret Key</span>
</VSCodeTextField>
<VSCodeTextField
value={apiConfiguration?.awsSessionToken || ""}
style={{ width: "100%" }}
type="password"
onInput={handleInputChange("awsSessionToken")}
placeholder="Enter Session Token...">
<span style={{ fontWeight: 500 }}>AWS Session Token</span>
</VSCodeTextField>
</>
)}
<div className="dropdown-container"> <div className="dropdown-container">
<label htmlFor="aws-region-dropdown"> <label htmlFor="aws-region-dropdown">
<span style={{ fontWeight: 500 }}>AWS Region</span> <span style={{ fontWeight: 500 }}>AWS Region</span>