mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
Merge pull request #392 from Lunchb0ne/aws-profile-support
Add support for using AWS Profile
This commit is contained in:
5
.changeset/afraid-pillows-kiss.md
Normal file
5
.changeset/afraid-pillows-kiss.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
"roo-cline": minor
|
||||||
|
---
|
||||||
|
|
||||||
|
Added suport for configuring Bedrock provider with AWS Profiles. Useful for users with SSO or other integrations who don't have access to long term credentials.
|
||||||
@@ -1,7 +1,16 @@
|
|||||||
|
// Mock AWS SDK credential providers
|
||||||
|
jest.mock("@aws-sdk/credential-providers", () => ({
|
||||||
|
fromIni: jest.fn().mockReturnValue({
|
||||||
|
accessKeyId: "profile-access-key",
|
||||||
|
secretAccessKey: "profile-secret-key",
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
import { AwsBedrockHandler } from "../bedrock"
|
import { AwsBedrockHandler } from "../bedrock"
|
||||||
import { MessageContent } from "../../../shared/api"
|
import { MessageContent } from "../../../shared/api"
|
||||||
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
|
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
|
||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
import { fromIni } from "@aws-sdk/credential-providers"
|
||||||
|
|
||||||
describe("AwsBedrockHandler", () => {
|
describe("AwsBedrockHandler", () => {
|
||||||
let handler: AwsBedrockHandler
|
let handler: AwsBedrockHandler
|
||||||
@@ -30,6 +39,57 @@ describe("AwsBedrockHandler", () => {
|
|||||||
})
|
})
|
||||||
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
|
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it("should initialize with AWS profile credentials", () => {
|
||||||
|
const handlerWithProfile = new AwsBedrockHandler({
|
||||||
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
awsRegion: "us-east-1",
|
||||||
|
awsUseProfile: true,
|
||||||
|
awsProfile: "test-profile",
|
||||||
|
})
|
||||||
|
expect(handlerWithProfile).toBeInstanceOf(AwsBedrockHandler)
|
||||||
|
expect(handlerWithProfile["options"].awsUseProfile).toBe(true)
|
||||||
|
expect(handlerWithProfile["options"].awsProfile).toBe("test-profile")
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should initialize with AWS profile enabled but no profile set", () => {
|
||||||
|
const handlerWithoutProfile = new AwsBedrockHandler({
|
||||||
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
awsRegion: "us-east-1",
|
||||||
|
awsUseProfile: true,
|
||||||
|
})
|
||||||
|
expect(handlerWithoutProfile).toBeInstanceOf(AwsBedrockHandler)
|
||||||
|
expect(handlerWithoutProfile["options"].awsUseProfile).toBe(true)
|
||||||
|
expect(handlerWithoutProfile["options"].awsProfile).toBeUndefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("AWS SDK client configuration", () => {
|
||||||
|
it("should configure client with profile credentials when profile mode is enabled", async () => {
|
||||||
|
const handlerWithProfile = new AwsBedrockHandler({
|
||||||
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
awsRegion: "us-east-1",
|
||||||
|
awsUseProfile: true,
|
||||||
|
awsProfile: "test-profile",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mock a simple API call to verify credentials are used
|
||||||
|
const mockResponse = {
|
||||||
|
output: new TextEncoder().encode(JSON.stringify({ content: "test" })),
|
||||||
|
}
|
||||||
|
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||||
|
handlerWithProfile["client"] = {
|
||||||
|
send: mockSend,
|
||||||
|
} as unknown as BedrockRuntimeClient
|
||||||
|
|
||||||
|
await handlerWithProfile.completePrompt("test")
|
||||||
|
|
||||||
|
// Verify the client was configured with profile credentials
|
||||||
|
expect(mockSend).toHaveBeenCalled()
|
||||||
|
expect(fromIni).toHaveBeenCalledWith({
|
||||||
|
profile: "test-profile",
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("createMessage", () => {
|
describe("createMessage", () => {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import {
|
|||||||
ConverseCommand,
|
ConverseCommand,
|
||||||
BedrockRuntimeClientConfig,
|
BedrockRuntimeClientConfig,
|
||||||
} from "@aws-sdk/client-bedrock-runtime"
|
} from "@aws-sdk/client-bedrock-runtime"
|
||||||
|
import { fromIni } from "@aws-sdk/credential-providers"
|
||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||||
@@ -50,13 +51,17 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
constructor(options: ApiHandlerOptions) {
|
constructor(options: ApiHandlerOptions) {
|
||||||
this.options = options
|
this.options = options
|
||||||
|
|
||||||
// Only include credentials if they actually exist
|
|
||||||
const clientConfig: BedrockRuntimeClientConfig = {
|
const clientConfig: BedrockRuntimeClientConfig = {
|
||||||
region: this.options.awsRegion || "us-east-1",
|
region: this.options.awsRegion || "us-east-1",
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
if (this.options.awsUseProfile && this.options.awsProfile) {
|
||||||
// Create credentials object with all properties at once
|
// Use profile-based credentials if enabled and profile is set
|
||||||
|
clientConfig.credentials = fromIni({
|
||||||
|
profile: this.options.awsProfile,
|
||||||
|
})
|
||||||
|
} else if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||||
|
// Use direct credentials if provided
|
||||||
clientConfig.credentials = {
|
clientConfig.credentials = {
|
||||||
accessKeyId: this.options.awsAccessKey,
|
accessKeyId: this.options.awsAccessKey,
|
||||||
secretAccessKey: this.options.awsSecretKey,
|
secretAccessKey: this.options.awsSecretKey,
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ type GlobalStateKey =
|
|||||||
| "glamaModelInfo"
|
| "glamaModelInfo"
|
||||||
| "awsRegion"
|
| "awsRegion"
|
||||||
| "awsUseCrossRegionInference"
|
| "awsUseCrossRegionInference"
|
||||||
|
| "awsProfile"
|
||||||
|
| "awsUseProfile"
|
||||||
| "vertexProjectId"
|
| "vertexProjectId"
|
||||||
| "vertexRegion"
|
| "vertexRegion"
|
||||||
| "lastShownAnnouncementId"
|
| "lastShownAnnouncementId"
|
||||||
@@ -1264,6 +1266,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
awsSessionToken,
|
awsSessionToken,
|
||||||
awsRegion,
|
awsRegion,
|
||||||
awsUseCrossRegionInference,
|
awsUseCrossRegionInference,
|
||||||
|
awsProfile,
|
||||||
|
awsUseProfile,
|
||||||
vertexProjectId,
|
vertexProjectId,
|
||||||
vertexRegion,
|
vertexRegion,
|
||||||
openAiBaseUrl,
|
openAiBaseUrl,
|
||||||
@@ -1299,6 +1303,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
await this.storeSecret("awsSessionToken", awsSessionToken)
|
await this.storeSecret("awsSessionToken", awsSessionToken)
|
||||||
await this.updateGlobalState("awsRegion", awsRegion)
|
await this.updateGlobalState("awsRegion", awsRegion)
|
||||||
await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference)
|
await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference)
|
||||||
|
await this.updateGlobalState("awsProfile", awsProfile)
|
||||||
|
await this.updateGlobalState("awsUseProfile", awsUseProfile)
|
||||||
await this.updateGlobalState("vertexProjectId", vertexProjectId)
|
await this.updateGlobalState("vertexProjectId", vertexProjectId)
|
||||||
await this.updateGlobalState("vertexRegion", vertexRegion)
|
await this.updateGlobalState("vertexRegion", vertexRegion)
|
||||||
await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
|
await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
|
||||||
@@ -1919,6 +1925,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
awsSessionToken,
|
awsSessionToken,
|
||||||
awsRegion,
|
awsRegion,
|
||||||
awsUseCrossRegionInference,
|
awsUseCrossRegionInference,
|
||||||
|
awsProfile,
|
||||||
|
awsUseProfile,
|
||||||
vertexProjectId,
|
vertexProjectId,
|
||||||
vertexRegion,
|
vertexRegion,
|
||||||
openAiBaseUrl,
|
openAiBaseUrl,
|
||||||
@@ -1985,6 +1993,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
this.getSecret("awsSessionToken") as Promise<string | undefined>,
|
this.getSecret("awsSessionToken") as Promise<string | undefined>,
|
||||||
this.getGlobalState("awsRegion") as Promise<string | undefined>,
|
this.getGlobalState("awsRegion") as Promise<string | undefined>,
|
||||||
this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>,
|
this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>,
|
||||||
|
this.getGlobalState("awsProfile") as Promise<string | undefined>,
|
||||||
|
this.getGlobalState("awsUseProfile") as Promise<boolean | undefined>,
|
||||||
this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
|
this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
|
||||||
this.getGlobalState("vertexRegion") as Promise<string | undefined>,
|
this.getGlobalState("vertexRegion") as Promise<string | undefined>,
|
||||||
this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
|
this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
|
||||||
@@ -2068,6 +2078,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
awsSessionToken,
|
awsSessionToken,
|
||||||
awsRegion,
|
awsRegion,
|
||||||
awsUseCrossRegionInference,
|
awsUseCrossRegionInference,
|
||||||
|
awsProfile,
|
||||||
|
awsUseProfile,
|
||||||
vertexProjectId,
|
vertexProjectId,
|
||||||
vertexRegion,
|
vertexRegion,
|
||||||
openAiBaseUrl,
|
openAiBaseUrl,
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ export interface ApiHandlerOptions {
|
|||||||
awsUseCrossRegionInference?: boolean
|
awsUseCrossRegionInference?: boolean
|
||||||
awsUsePromptCache?: boolean
|
awsUsePromptCache?: boolean
|
||||||
awspromptCacheId?: string
|
awspromptCacheId?: string
|
||||||
|
awsProfile?: string
|
||||||
|
awsUseProfile?: boolean
|
||||||
vertexProjectId?: string
|
vertexProjectId?: string
|
||||||
vertexRegion?: string
|
vertexRegion?: string
|
||||||
openAiBaseUrl?: string
|
openAiBaseUrl?: string
|
||||||
|
|||||||
@@ -340,30 +340,56 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
|
|||||||
|
|
||||||
{selectedProvider === "bedrock" && (
|
{selectedProvider === "bedrock" && (
|
||||||
<div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
|
<div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
|
||||||
<VSCodeTextField
|
<VSCodeRadioGroup
|
||||||
value={apiConfiguration?.awsAccessKey || ""}
|
value={apiConfiguration?.awsUseProfile ? "profile" : "credentials"}
|
||||||
style={{ width: "100%" }}
|
onChange={(e) => {
|
||||||
type="password"
|
const value = (e.target as HTMLInputElement)?.value
|
||||||
onInput={handleInputChange("awsAccessKey")}
|
const useProfile = value === "profile"
|
||||||
placeholder="Enter Access Key...">
|
handleInputChange("awsUseProfile")({
|
||||||
<span style={{ fontWeight: 500 }}>AWS Access Key</span>
|
target: { value: useProfile },
|
||||||
</VSCodeTextField>
|
})
|
||||||
<VSCodeTextField
|
}}>
|
||||||
value={apiConfiguration?.awsSecretKey || ""}
|
<VSCodeRadio value="credentials">AWS Credentials</VSCodeRadio>
|
||||||
style={{ width: "100%" }}
|
<VSCodeRadio value="profile">AWS Profile</VSCodeRadio>
|
||||||
type="password"
|
</VSCodeRadioGroup>
|
||||||
onInput={handleInputChange("awsSecretKey")}
|
{/* AWS Profile Config Block */}
|
||||||
placeholder="Enter Secret Key...">
|
{apiConfiguration?.awsUseProfile ? (
|
||||||
<span style={{ fontWeight: 500 }}>AWS Secret Key</span>
|
<VSCodeTextField
|
||||||
</VSCodeTextField>
|
value={apiConfiguration?.awsProfile || ""}
|
||||||
<VSCodeTextField
|
style={{ width: "100%" }}
|
||||||
value={apiConfiguration?.awsSessionToken || ""}
|
onInput={handleInputChange("awsProfile")}
|
||||||
style={{ width: "100%" }}
|
placeholder="Enter profile name">
|
||||||
type="password"
|
<span style={{ fontWeight: 500 }}>AWS Profile Name</span>
|
||||||
onInput={handleInputChange("awsSessionToken")}
|
</VSCodeTextField>
|
||||||
placeholder="Enter Session Token...">
|
) : (
|
||||||
<span style={{ fontWeight: 500 }}>AWS Session Token</span>
|
<>
|
||||||
</VSCodeTextField>
|
{/* AWS Credentials Config Block */}
|
||||||
|
<VSCodeTextField
|
||||||
|
value={apiConfiguration?.awsAccessKey || ""}
|
||||||
|
style={{ width: "100%" }}
|
||||||
|
type="password"
|
||||||
|
onInput={handleInputChange("awsAccessKey")}
|
||||||
|
placeholder="Enter Access Key...">
|
||||||
|
<span style={{ fontWeight: 500 }}>AWS Access Key</span>
|
||||||
|
</VSCodeTextField>
|
||||||
|
<VSCodeTextField
|
||||||
|
value={apiConfiguration?.awsSecretKey || ""}
|
||||||
|
style={{ width: "100%" }}
|
||||||
|
type="password"
|
||||||
|
onInput={handleInputChange("awsSecretKey")}
|
||||||
|
placeholder="Enter Secret Key...">
|
||||||
|
<span style={{ fontWeight: 500 }}>AWS Secret Key</span>
|
||||||
|
</VSCodeTextField>
|
||||||
|
<VSCodeTextField
|
||||||
|
value={apiConfiguration?.awsSessionToken || ""}
|
||||||
|
style={{ width: "100%" }}
|
||||||
|
type="password"
|
||||||
|
onInput={handleInputChange("awsSessionToken")}
|
||||||
|
placeholder="Enter Session Token...">
|
||||||
|
<span style={{ fontWeight: 500 }}>AWS Session Token</span>
|
||||||
|
</VSCodeTextField>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
<div className="dropdown-container">
|
<div className="dropdown-container">
|
||||||
<label htmlFor="aws-region-dropdown">
|
<label htmlFor="aws-region-dropdown">
|
||||||
<span style={{ fontWeight: 500 }}>AWS Region</span>
|
<span style={{ fontWeight: 500 }}>AWS Region</span>
|
||||||
|
|||||||
Reference in New Issue
Block a user