mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
Prettier backfill
This commit is contained in:
@@ -1,20 +1,20 @@
|
|||||||
// Half-works to simplify the format but needs 'overwrite_changeset_changelog.py' in GHA to finish formatting
|
// Half-works to simplify the format but needs 'overwrite_changeset_changelog.py' in GHA to finish formatting
|
||||||
|
|
||||||
const getReleaseLine = async (changeset) => {
|
const getReleaseLine = async (changeset) => {
|
||||||
const [firstLine] = changeset.summary
|
const [firstLine] = changeset.summary
|
||||||
.split('\n')
|
.split("\n")
|
||||||
.map(l => l.trim())
|
.map((l) => l.trim())
|
||||||
.filter(Boolean);
|
.filter(Boolean)
|
||||||
return `- ${firstLine}`;
|
return `- ${firstLine}`
|
||||||
};
|
}
|
||||||
|
|
||||||
const getDependencyReleaseLine = async () => {
|
const getDependencyReleaseLine = async () => {
|
||||||
return '';
|
return ""
|
||||||
};
|
}
|
||||||
|
|
||||||
const changelogFunctions = {
|
const changelogFunctions = {
|
||||||
getReleaseLine,
|
getReleaseLine,
|
||||||
getDependencyReleaseLine,
|
getDependencyReleaseLine,
|
||||||
};
|
}
|
||||||
|
|
||||||
module.exports = changelogFunctions;
|
module.exports = changelogFunctions
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
{
|
{
|
||||||
"$schema": "https://unpkg.com/@changesets/config@3.0.4/schema.json",
|
"$schema": "https://unpkg.com/@changesets/config@3.0.4/schema.json",
|
||||||
"changelog": "./changelog-config.js",
|
"changelog": "./changelog-config.js",
|
||||||
"commit": false,
|
"commit": false,
|
||||||
"fixed": [],
|
"fixed": [],
|
||||||
"linked": [],
|
"linked": [],
|
||||||
"access": "restricted",
|
"access": "restricted",
|
||||||
"baseBranch": "main",
|
"baseBranch": "main",
|
||||||
"updateInternalDependencies": "patch",
|
"updateInternalDependencies": "patch",
|
||||||
"ignore": []
|
"ignore": []
|
||||||
}
|
}
|
||||||
|
|||||||
11
.github/pull_request_template.md
vendored
11
.github/pull_request_template.md
vendored
@@ -1,28 +1,37 @@
|
|||||||
<!-- **Note:** Consider creating PRs as a DRAFT. For early feedback and self-review. -->
|
<!-- **Note:** Consider creating PRs as a DRAFT. For early feedback and self-review. -->
|
||||||
|
|
||||||
## Description
|
## Description
|
||||||
|
|
||||||
## Type of change
|
## Type of change
|
||||||
|
|
||||||
<!-- Please ignore options that are not relevant -->
|
<!-- Please ignore options that are not relevant -->
|
||||||
|
|
||||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||||
- [ ] New feature
|
- [ ] New feature
|
||||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||||
- [ ] This change requires a documentation update
|
- [ ] This change requires a documentation update
|
||||||
|
|
||||||
## How Has This Been Tested?
|
## How Has This Been Tested?
|
||||||
|
|
||||||
<!-- Please describe the tests that you ran to verify your changes -->
|
<!-- Please describe the tests that you ran to verify your changes -->
|
||||||
|
|
||||||
## Checklist:
|
## Checklist:
|
||||||
|
|
||||||
<!-- Go over all the following points, and put an `x` in all the boxes that apply -->
|
<!-- Go over all the following points, and put an `x` in all the boxes that apply -->
|
||||||
|
|
||||||
- [ ] My code follows the patterns of this project
|
- [ ] My code follows the patterns of this project
|
||||||
- [ ] I have performed a self-review of my own code
|
- [ ] I have performed a self-review of my own code
|
||||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||||
- [ ] I have made corresponding changes to the documentation
|
- [ ] I have made corresponding changes to the documentation
|
||||||
|
|
||||||
## Additional context
|
## Additional context
|
||||||
|
|
||||||
<!-- Add any other context or screenshots about the pull request here -->
|
<!-- Add any other context or screenshots about the pull request here -->
|
||||||
|
|
||||||
## Related Issues
|
## Related Issues
|
||||||
|
|
||||||
<!-- List any related issues here. Use the GitHub issue linking syntax: #issue-number -->
|
<!-- List any related issues here. Use the GitHub issue linking syntax: #issue-number -->
|
||||||
|
|
||||||
## Reviewers
|
## Reviewers
|
||||||
<!-- @mention specific team members or individuals who should review this PR -->
|
|
||||||
|
<!-- @mention specific team members or individuals who should review this PR -->
|
||||||
|
|||||||
43
README.md
43
README.md
@@ -10,14 +10,15 @@ Hot off the heels of **v3.0** introducing Code, Architect, and Ask chat modes, o
|
|||||||
|
|
||||||
You can now tailor the **role definition** and **custom instructions** for every chat mode to perfectly fit your workflow. Want to adjust Architect mode to focus more on system scalability? Or tweak Ask mode for deeper research queries? Done. Plus, you can define these via **mode-specific `.clinerules-[mode]` files**. You’ll find all of this in the new **Prompts** tab in the top menu.
|
You can now tailor the **role definition** and **custom instructions** for every chat mode to perfectly fit your workflow. Want to adjust Architect mode to focus more on system scalability? Or tweak Ask mode for deeper research queries? Done. Plus, you can define these via **mode-specific `.clinerules-[mode]` files**. You’ll find all of this in the new **Prompts** tab in the top menu.
|
||||||
|
|
||||||
The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Here’s what’s new:
|
The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Here’s what’s new:
|
||||||
- Works with **any provider** and API configuration, not just OpenRouter.
|
|
||||||
- Fully customizable prompts to match your unique needs.
|
- Works with **any provider** and API configuration, not just OpenRouter.
|
||||||
|
- Fully customizable prompts to match your unique needs.
|
||||||
- Same simple workflow: just hit the ✨ **Enhance Prompt** button in the chat input to try it out.
|
- Same simple workflow: just hit the ✨ **Enhance Prompt** button in the chat input to try it out.
|
||||||
|
|
||||||
Whether you’re using GPT-4, other APIs, or switching configurations, this gives you total control over how your prompts are optimized.
|
Whether you’re using GPT-4, other APIs, or switching configurations, this gives you total control over how your prompts are optimized.
|
||||||
|
|
||||||
As always, we’d love to hear your thoughts and ideas! What features do you want to see in **v3.2**? Drop by https://www.reddit.com/r/roocline and join the discussion - we're building Roo Cline together. 🚀
|
As always, we’d love to hear your thoughts and ideas! What features do you want to see in **v3.2**? Drop by https://www.reddit.com/r/roocline and join the discussion - we're building Roo Cline together. 🚀
|
||||||
|
|
||||||
## New in 3.0 - Chat Modes!
|
## New in 3.0 - Chat Modes!
|
||||||
|
|
||||||
@@ -33,6 +34,7 @@ You can now choose between different prompts for Roo Cline to better suit your w
|
|||||||
It’s super simple! There’s a dropdown in the bottom left of the chat input to switch modes. Right next to it, you’ll find a way to switch between the API configuration profiles associated with the current mode (configured on the settings screen).
|
It’s super simple! There’s a dropdown in the bottom left of the chat input to switch modes. Right next to it, you’ll find a way to switch between the API configuration profiles associated with the current mode (configured on the settings screen).
|
||||||
|
|
||||||
**Why Add This?**
|
**Why Add This?**
|
||||||
|
|
||||||
- It keeps Cline from being overly eager to jump into solving problems when you just want to think or ask questions.
|
- It keeps Cline from being overly eager to jump into solving problems when you just want to think or ask questions.
|
||||||
- Each mode remembers the API configuration you last used with it. For example, you can use more thoughtful models like OpenAI o1 for Architect and Ask, while sticking with Sonnet or DeepSeek for coding tasks.
|
- Each mode remembers the API configuration you last used with it. For example, you can use more thoughtful models like OpenAI o1 for Architect and Ask, while sticking with Sonnet or DeepSeek for coding tasks.
|
||||||
- It builds on research suggesting better results when separating "thinking" from "coding," explained well in this very thoughtful [article](https://aider.chat/2024/09/26/architect.html) from aider.
|
- It builds on research suggesting better results when separating "thinking" from "coding," explained well in this very thoughtful [article](https://aider.chat/2024/09/26/architect.html) from aider.
|
||||||
@@ -50,25 +52,27 @@ Here's an example of Roo-Cline autonomously creating a snake game with "Always a
|
|||||||
https://github.com/user-attachments/assets/c2bb31dc-e9b2-4d73-885d-17f1471a4987
|
https://github.com/user-attachments/assets/c2bb31dc-e9b2-4d73-885d-17f1471a4987
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
To contribute to the project, start by exploring [open issues](https://github.com/RooVetGit/Roo-Cline/issues) or checking our [feature request board](https://github.com/RooVetGit/Roo-Cline/discussions/categories/feature-requests). We'd also love to have you join the [Roo Cline Reddit](https://www.reddit.com/r/roocline/) to share ideas and connect with other contributors.
|
To contribute to the project, start by exploring [open issues](https://github.com/RooVetGit/Roo-Cline/issues) or checking our [feature request board](https://github.com/RooVetGit/Roo-Cline/discussions/categories/feature-requests). We'd also love to have you join the [Roo Cline Reddit](https://www.reddit.com/r/roocline/) to share ideas and connect with other contributors.
|
||||||
|
|
||||||
### Local Setup
|
### Local Setup
|
||||||
|
|
||||||
1. Install dependencies:
|
1. Install dependencies:
|
||||||
```bash
|
|
||||||
npm run install:all
|
```bash
|
||||||
```
|
npm run install:all
|
||||||
|
```
|
||||||
|
|
||||||
2. Build the VSIX file:
|
2. Build the VSIX file:
|
||||||
```bash
|
```bash
|
||||||
npm run build
|
npm run build
|
||||||
```
|
```
|
||||||
3. The new VSIX file will be created in the `bin/` directory
|
3. The new VSIX file will be created in the `bin/` directory
|
||||||
4. Install the extension from the VSIX file as described below:
|
4. Install the extension from the VSIX file as described below:
|
||||||
|
|
||||||
- **Option 1:** Drag and drop the `.vsix` file into your VSCode-compatible editor's Extensions panel (Cmd/Ctrl+Shift+X).
|
- **Option 1:** Drag and drop the `.vsix` file into your VSCode-compatible editor's Extensions panel (Cmd/Ctrl+Shift+X).
|
||||||
|
|
||||||
- **Option 2:** Install the plugin using the CLI, make sure you have your VSCode-compatible CLI installed and in your `PATH` variable. Cursor example: `export PATH="$PATH:/Applications/Cursor.app/Contents/MacOS"`
|
- **Option 2:** Install the plugin using the CLI, make sure you have your VSCode-compatible CLI installed and in your `PATH` variable. Cursor example: `export PATH="$PATH:/Applications/Cursor.app/Contents/MacOS"`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ex: cursor --install-extension bin/roo-cline-2.0.1.vsix
|
# Ex: cursor --install-extension bin/roo-cline-2.0.1.vsix
|
||||||
@@ -83,16 +87,17 @@ We use [changesets](https://github.com/changesets/changesets) for versioning and
|
|||||||
|
|
||||||
1. Create a PR with your changes
|
1. Create a PR with your changes
|
||||||
2. Create a new changeset by running `npm run changeset`
|
2. Create a new changeset by running `npm run changeset`
|
||||||
- Select the appropriate kind of change - `patch` for bug fixes, `minor` for new features, or `major` for breaking changes
|
- Select the appropriate kind of change - `patch` for bug fixes, `minor` for new features, or `major` for breaking changes
|
||||||
- Write a clear description of your changes that will be included in the changelog
|
- Write a clear description of your changes that will be included in the changelog
|
||||||
3. Get the PR approved and pass all checks
|
3. Get the PR approved and pass all checks
|
||||||
4. Merge it
|
4. Merge it
|
||||||
|
|
||||||
Once your merge is successful:
|
Once your merge is successful:
|
||||||
|
|
||||||
- The release workflow will automatically create a new "Changeset version bump" PR
|
- The release workflow will automatically create a new "Changeset version bump" PR
|
||||||
- This PR will:
|
- This PR will:
|
||||||
- Update the version based on your changeset
|
- Update the version based on your changeset
|
||||||
- Update the `CHANGELOG.md` file
|
- Update the `CHANGELOG.md` file
|
||||||
- Once the PR is approved and merged, a new version will be published
|
- Once the PR is approved and merged, a new version will be published
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -193,9 +198,9 @@ Try asking Cline to "test the app", and watch as he runs a command like `npm run
|
|||||||
|
|
||||||
Thanks to the [Model Context Protocol](https://github.com/modelcontextprotocol), Cline can extend his capabilities through custom tools. While you can use [community-made servers](https://github.com/modelcontextprotocol/servers), Cline can instead create and install tools tailored to your specific workflow. Just ask Cline to "add a tool" and he will handle everything, from creating a new MCP server to installing it into the extension. These custom tools then become part of Cline's toolkit, ready to use in future tasks.
|
Thanks to the [Model Context Protocol](https://github.com/modelcontextprotocol), Cline can extend his capabilities through custom tools. While you can use [community-made servers](https://github.com/modelcontextprotocol/servers), Cline can instead create and install tools tailored to your specific workflow. Just ask Cline to "add a tool" and he will handle everything, from creating a new MCP server to installing it into the extension. These custom tools then become part of Cline's toolkit, ready to use in future tasks.
|
||||||
|
|
||||||
- "add a tool that fetches Jira tickets": Retrieve ticket ACs and put Cline to work
|
- "add a tool that fetches Jira tickets": Retrieve ticket ACs and put Cline to work
|
||||||
- "add a tool that manages AWS EC2s": Check server metrics and scale instances up or down
|
- "add a tool that manages AWS EC2s": Check server metrics and scale instances up or down
|
||||||
- "add a tool that pulls the latest PagerDuty incidents": Fetch details and ask Cline to fix bugs
|
- "add a tool that pulls the latest PagerDuty incidents": Fetch details and ask Cline to fix bugs
|
||||||
|
|
||||||
<!-- Transparent pixel to create line break after floating image -->
|
<!-- Transparent pixel to create line break after floating image -->
|
||||||
|
|
||||||
|
|||||||
@@ -1,137 +1,146 @@
|
|||||||
|
|
||||||
## For All Settings
|
## For All Settings
|
||||||
|
|
||||||
1. Add the setting to ExtensionMessage.ts:
|
1. Add the setting to ExtensionMessage.ts:
|
||||||
- Add the setting to the ExtensionState interface
|
|
||||||
- Make it required if it has a default value, optional if it can be undefined
|
- Add the setting to the ExtensionState interface
|
||||||
- Example: `preferredLanguage: string`
|
- Make it required if it has a default value, optional if it can be undefined
|
||||||
|
- Example: `preferredLanguage: string`
|
||||||
|
|
||||||
2. Add test coverage:
|
2. Add test coverage:
|
||||||
- Add the setting to mockState in ClineProvider.test.ts
|
- Add the setting to mockState in ClineProvider.test.ts
|
||||||
- Add test cases for setting persistence and state updates
|
- Add test cases for setting persistence and state updates
|
||||||
- Ensure all tests pass before submitting changes
|
- Ensure all tests pass before submitting changes
|
||||||
|
|
||||||
## For Checkbox Settings
|
## For Checkbox Settings
|
||||||
|
|
||||||
1. Add the message type to WebviewMessage.ts:
|
1. Add the message type to WebviewMessage.ts:
|
||||||
- Add the setting name to the WebviewMessage type's type union
|
|
||||||
- Example: `| "multisearchDiffEnabled"`
|
- Add the setting name to the WebviewMessage type's type union
|
||||||
|
- Example: `| "multisearchDiffEnabled"`
|
||||||
|
|
||||||
2. Add the setting to ExtensionStateContext.tsx:
|
2. Add the setting to ExtensionStateContext.tsx:
|
||||||
- Add the setting to the ExtensionStateContextType interface
|
|
||||||
- Add the setter function to the interface
|
- Add the setting to the ExtensionStateContextType interface
|
||||||
- Add the setting to the initial state in useState
|
- Add the setter function to the interface
|
||||||
- Add the setting to the contextValue object
|
- Add the setting to the initial state in useState
|
||||||
- Example:
|
- Add the setting to the contextValue object
|
||||||
```typescript
|
- Example:
|
||||||
interface ExtensionStateContextType {
|
```typescript
|
||||||
multisearchDiffEnabled: boolean;
|
interface ExtensionStateContextType {
|
||||||
setMultisearchDiffEnabled: (value: boolean) => void;
|
multisearchDiffEnabled: boolean
|
||||||
}
|
setMultisearchDiffEnabled: (value: boolean) => void
|
||||||
```
|
}
|
||||||
|
```
|
||||||
|
|
||||||
3. Add the setting to ClineProvider.ts:
|
3. Add the setting to ClineProvider.ts:
|
||||||
- Add the setting name to the GlobalStateKey type union
|
|
||||||
- Add the setting to the Promise.all array in getState
|
- Add the setting name to the GlobalStateKey type union
|
||||||
- Add the setting to the return value in getState with a default value
|
- Add the setting to the Promise.all array in getState
|
||||||
- Add the setting to the destructured variables in getStateToPostToWebview
|
- Add the setting to the return value in getState with a default value
|
||||||
- Add the setting to the return value in getStateToPostToWebview
|
- Add the setting to the destructured variables in getStateToPostToWebview
|
||||||
- Add a case in setWebviewMessageListener to handle the setting's message type
|
- Add the setting to the return value in getStateToPostToWebview
|
||||||
- Example:
|
- Add a case in setWebviewMessageListener to handle the setting's message type
|
||||||
```typescript
|
- Example:
|
||||||
case "multisearchDiffEnabled":
|
```typescript
|
||||||
await this.updateGlobalState("multisearchDiffEnabled", message.bool)
|
case "multisearchDiffEnabled":
|
||||||
await this.postStateToWebview()
|
await this.updateGlobalState("multisearchDiffEnabled", message.bool)
|
||||||
break
|
await this.postStateToWebview()
|
||||||
```
|
break
|
||||||
|
```
|
||||||
|
|
||||||
4. Add the checkbox UI to SettingsView.tsx:
|
4. Add the checkbox UI to SettingsView.tsx:
|
||||||
- Import the setting and its setter from ExtensionStateContext
|
|
||||||
- Add the VSCodeCheckbox component with the setting's state and onChange handler
|
- Import the setting and its setter from ExtensionStateContext
|
||||||
- Add appropriate labels and description text
|
- Add the VSCodeCheckbox component with the setting's state and onChange handler
|
||||||
- Example:
|
- Add appropriate labels and description text
|
||||||
```typescript
|
- Example:
|
||||||
<VSCodeCheckbox
|
```typescript
|
||||||
checked={multisearchDiffEnabled}
|
<VSCodeCheckbox
|
||||||
onChange={(e: any) => setMultisearchDiffEnabled(e.target.checked)}
|
checked={multisearchDiffEnabled}
|
||||||
>
|
onChange={(e: any) => setMultisearchDiffEnabled(e.target.checked)}
|
||||||
<span style={{ fontWeight: "500" }}>Enable multi-search diff matching</span>
|
>
|
||||||
</VSCodeCheckbox>
|
<span style={{ fontWeight: "500" }}>Enable multi-search diff matching</span>
|
||||||
```
|
</VSCodeCheckbox>
|
||||||
|
```
|
||||||
|
|
||||||
5. Add the setting to handleSubmit in SettingsView.tsx:
|
5. Add the setting to handleSubmit in SettingsView.tsx:
|
||||||
- Add a vscode.postMessage call to send the setting's value when clicking Done
|
- Add a vscode.postMessage call to send the setting's value when clicking Done
|
||||||
- Example:
|
- Example:
|
||||||
```typescript
|
```typescript
|
||||||
vscode.postMessage({ type: "multisearchDiffEnabled", bool: multisearchDiffEnabled })
|
vscode.postMessage({ type: "multisearchDiffEnabled", bool: multisearchDiffEnabled })
|
||||||
```
|
```
|
||||||
|
|
||||||
## For Select/Dropdown Settings
|
## For Select/Dropdown Settings
|
||||||
|
|
||||||
1. Add the message type to WebviewMessage.ts:
|
1. Add the message type to WebviewMessage.ts:
|
||||||
- Add the setting name to the WebviewMessage type's type union
|
|
||||||
- Example: `| "preferredLanguage"`
|
- Add the setting name to the WebviewMessage type's type union
|
||||||
|
- Example: `| "preferredLanguage"`
|
||||||
|
|
||||||
2. Add the setting to ExtensionStateContext.tsx:
|
2. Add the setting to ExtensionStateContext.tsx:
|
||||||
- Add the setting to the ExtensionStateContextType interface
|
|
||||||
- Add the setter function to the interface
|
- Add the setting to the ExtensionStateContextType interface
|
||||||
- Add the setting to the initial state in useState with a default value
|
- Add the setter function to the interface
|
||||||
- Add the setting to the contextValue object
|
- Add the setting to the initial state in useState with a default value
|
||||||
- Example:
|
- Add the setting to the contextValue object
|
||||||
```typescript
|
- Example:
|
||||||
interface ExtensionStateContextType {
|
```typescript
|
||||||
preferredLanguage: string;
|
interface ExtensionStateContextType {
|
||||||
setPreferredLanguage: (value: string) => void;
|
preferredLanguage: string
|
||||||
}
|
setPreferredLanguage: (value: string) => void
|
||||||
```
|
}
|
||||||
|
```
|
||||||
|
|
||||||
3. Add the setting to ClineProvider.ts:
|
3. Add the setting to ClineProvider.ts:
|
||||||
- Add the setting name to the GlobalStateKey type union
|
|
||||||
- Add the setting to the Promise.all array in getState
|
- Add the setting name to the GlobalStateKey type union
|
||||||
- Add the setting to the return value in getState with a default value
|
- Add the setting to the Promise.all array in getState
|
||||||
- Add the setting to the destructured variables in getStateToPostToWebview
|
- Add the setting to the return value in getState with a default value
|
||||||
- Add the setting to the return value in getStateToPostToWebview
|
- Add the setting to the destructured variables in getStateToPostToWebview
|
||||||
- Add a case in setWebviewMessageListener to handle the setting's message type
|
- Add the setting to the return value in getStateToPostToWebview
|
||||||
- Example:
|
- Add a case in setWebviewMessageListener to handle the setting's message type
|
||||||
```typescript
|
- Example:
|
||||||
case "preferredLanguage":
|
```typescript
|
||||||
await this.updateGlobalState("preferredLanguage", message.text)
|
case "preferredLanguage":
|
||||||
await this.postStateToWebview()
|
await this.updateGlobalState("preferredLanguage", message.text)
|
||||||
break
|
await this.postStateToWebview()
|
||||||
```
|
break
|
||||||
|
```
|
||||||
|
|
||||||
4. Add the select UI to SettingsView.tsx:
|
4. Add the select UI to SettingsView.tsx:
|
||||||
- Import the setting and its setter from ExtensionStateContext
|
|
||||||
- Add the select element with appropriate styling to match VSCode's theme
|
- Import the setting and its setter from ExtensionStateContext
|
||||||
- Add options for the dropdown
|
- Add the select element with appropriate styling to match VSCode's theme
|
||||||
- Add appropriate labels and description text
|
- Add options for the dropdown
|
||||||
- Example:
|
- Add appropriate labels and description text
|
||||||
```typescript
|
- Example:
|
||||||
<select
|
```typescript
|
||||||
value={preferredLanguage}
|
<select
|
||||||
onChange={(e) => setPreferredLanguage(e.target.value)}
|
value={preferredLanguage}
|
||||||
style={{
|
onChange={(e) => setPreferredLanguage(e.target.value)}
|
||||||
width: "100%",
|
style={{
|
||||||
padding: "4px 8px",
|
width: "100%",
|
||||||
backgroundColor: "var(--vscode-input-background)",
|
padding: "4px 8px",
|
||||||
color: "var(--vscode-input-foreground)",
|
backgroundColor: "var(--vscode-input-background)",
|
||||||
border: "1px solid var(--vscode-input-border)",
|
color: "var(--vscode-input-foreground)",
|
||||||
borderRadius: "2px"
|
border: "1px solid var(--vscode-input-border)",
|
||||||
}}>
|
borderRadius: "2px"
|
||||||
<option value="English">English</option>
|
}}>
|
||||||
<option value="Spanish">Spanish</option>
|
<option value="English">English</option>
|
||||||
...
|
<option value="Spanish">Spanish</option>
|
||||||
</select>
|
...
|
||||||
```
|
</select>
|
||||||
|
```
|
||||||
|
|
||||||
5. Add the setting to handleSubmit in SettingsView.tsx:
|
5. Add the setting to handleSubmit in SettingsView.tsx:
|
||||||
- Add a vscode.postMessage call to send the setting's value when clicking Done
|
- Add a vscode.postMessage call to send the setting's value when clicking Done
|
||||||
- Example:
|
- Example:
|
||||||
```typescript
|
```typescript
|
||||||
vscode.postMessage({ type: "preferredLanguage", text: preferredLanguage })
|
vscode.postMessage({ type: "preferredLanguage", text: preferredLanguage })
|
||||||
```
|
```
|
||||||
|
|
||||||
These steps ensure that:
|
These steps ensure that:
|
||||||
|
|
||||||
- The setting's state is properly typed throughout the application
|
- The setting's state is properly typed throughout the application
|
||||||
- The setting persists between sessions
|
- The setting persists between sessions
|
||||||
- The setting's value is properly synchronized between the webview and extension
|
- The setting's value is properly synchronized between the webview and extension
|
||||||
|
|||||||
@@ -1,41 +1,40 @@
|
|||||||
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||||
module.exports = {
|
module.exports = {
|
||||||
preset: 'ts-jest',
|
preset: "ts-jest",
|
||||||
testEnvironment: 'node',
|
testEnvironment: "node",
|
||||||
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
|
moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"],
|
||||||
transform: {
|
transform: {
|
||||||
'^.+\\.tsx?$': ['ts-jest', {
|
"^.+\\.tsx?$": [
|
||||||
tsconfig: {
|
"ts-jest",
|
||||||
"module": "CommonJS",
|
{
|
||||||
"moduleResolution": "node",
|
tsconfig: {
|
||||||
"esModuleInterop": true,
|
module: "CommonJS",
|
||||||
"allowJs": true
|
moduleResolution: "node",
|
||||||
},
|
esModuleInterop: true,
|
||||||
diagnostics: false,
|
allowJs: true,
|
||||||
isolatedModules: true
|
},
|
||||||
}]
|
diagnostics: false,
|
||||||
},
|
isolatedModules: true,
|
||||||
testMatch: ['**/__tests__/**/*.test.ts'],
|
},
|
||||||
moduleNameMapper: {
|
],
|
||||||
'^vscode$': '<rootDir>/src/__mocks__/vscode.js',
|
},
|
||||||
'@modelcontextprotocol/sdk$': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js',
|
testMatch: ["**/__tests__/**/*.test.ts"],
|
||||||
'@modelcontextprotocol/sdk/(.*)': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1',
|
moduleNameMapper: {
|
||||||
'^delay$': '<rootDir>/src/__mocks__/delay.js',
|
"^vscode$": "<rootDir>/src/__mocks__/vscode.js",
|
||||||
'^p-wait-for$': '<rootDir>/src/__mocks__/p-wait-for.js',
|
"@modelcontextprotocol/sdk$": "<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js",
|
||||||
'^globby$': '<rootDir>/src/__mocks__/globby.js',
|
"@modelcontextprotocol/sdk/(.*)": "<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1",
|
||||||
'^serialize-error$': '<rootDir>/src/__mocks__/serialize-error.js',
|
"^delay$": "<rootDir>/src/__mocks__/delay.js",
|
||||||
'^strip-ansi$': '<rootDir>/src/__mocks__/strip-ansi.js',
|
"^p-wait-for$": "<rootDir>/src/__mocks__/p-wait-for.js",
|
||||||
'^default-shell$': '<rootDir>/src/__mocks__/default-shell.js',
|
"^globby$": "<rootDir>/src/__mocks__/globby.js",
|
||||||
'^os-name$': '<rootDir>/src/__mocks__/os-name.js'
|
"^serialize-error$": "<rootDir>/src/__mocks__/serialize-error.js",
|
||||||
},
|
"^strip-ansi$": "<rootDir>/src/__mocks__/strip-ansi.js",
|
||||||
transformIgnorePatterns: [
|
"^default-shell$": "<rootDir>/src/__mocks__/default-shell.js",
|
||||||
'node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)'
|
"^os-name$": "<rootDir>/src/__mocks__/os-name.js",
|
||||||
],
|
},
|
||||||
modulePathIgnorePatterns: [
|
transformIgnorePatterns: [
|
||||||
'.vscode-test'
|
"node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)",
|
||||||
],
|
],
|
||||||
reporters: [
|
modulePathIgnorePatterns: [".vscode-test"],
|
||||||
["jest-simple-dot-reporter", {}]
|
reporters: [["jest-simple-dot-reporter", {}]],
|
||||||
],
|
setupFiles: [],
|
||||||
setupFiles: []
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
class Client {
|
class Client {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.request = jest.fn()
|
this.request = jest.fn()
|
||||||
}
|
}
|
||||||
|
|
||||||
connect() {
|
connect() {
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
|
|
||||||
close() {
|
close() {
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
Client
|
Client,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
class StdioClientTransport {
|
class StdioClientTransport {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.start = jest.fn().mockResolvedValue(undefined)
|
this.start = jest.fn().mockResolvedValue(undefined)
|
||||||
this.close = jest.fn().mockResolvedValue(undefined)
|
this.close = jest.fn().mockResolvedValue(undefined)
|
||||||
this.stderr = {
|
this.stderr = {
|
||||||
on: jest.fn()
|
on: jest.fn(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class StdioServerParameters {
|
class StdioServerParameters {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.command = ''
|
this.command = ""
|
||||||
this.args = []
|
this.args = []
|
||||||
this.env = {}
|
this.env = {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
StdioClientTransport,
|
StdioClientTransport,
|
||||||
StdioServerParameters
|
StdioServerParameters,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,24 +1,24 @@
|
|||||||
const { Client } = require('./client/index.js')
|
const { Client } = require("./client/index.js")
|
||||||
const { StdioClientTransport, StdioServerParameters } = require('./client/stdio.js')
|
const { StdioClientTransport, StdioServerParameters } = require("./client/stdio.js")
|
||||||
const {
|
const {
|
||||||
CallToolResultSchema,
|
CallToolResultSchema,
|
||||||
ListToolsResultSchema,
|
ListToolsResultSchema,
|
||||||
ListResourcesResultSchema,
|
ListResourcesResultSchema,
|
||||||
ListResourceTemplatesResultSchema,
|
ListResourceTemplatesResultSchema,
|
||||||
ReadResourceResultSchema,
|
ReadResourceResultSchema,
|
||||||
ErrorCode,
|
ErrorCode,
|
||||||
McpError
|
McpError,
|
||||||
} = require('./types.js')
|
} = require("./types.js")
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
Client,
|
Client,
|
||||||
StdioClientTransport,
|
StdioClientTransport,
|
||||||
StdioServerParameters,
|
StdioServerParameters,
|
||||||
CallToolResultSchema,
|
CallToolResultSchema,
|
||||||
ListToolsResultSchema,
|
ListToolsResultSchema,
|
||||||
ListResourcesResultSchema,
|
ListResourcesResultSchema,
|
||||||
ListResourceTemplatesResultSchema,
|
ListResourceTemplatesResultSchema,
|
||||||
ReadResourceResultSchema,
|
ReadResourceResultSchema,
|
||||||
ErrorCode,
|
ErrorCode,
|
||||||
McpError
|
McpError,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,51 +1,51 @@
|
|||||||
const CallToolResultSchema = {
|
const CallToolResultSchema = {
|
||||||
parse: jest.fn().mockReturnValue({})
|
parse: jest.fn().mockReturnValue({}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const ListToolsResultSchema = {
|
const ListToolsResultSchema = {
|
||||||
parse: jest.fn().mockReturnValue({
|
parse: jest.fn().mockReturnValue({
|
||||||
tools: []
|
tools: [],
|
||||||
})
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const ListResourcesResultSchema = {
|
const ListResourcesResultSchema = {
|
||||||
parse: jest.fn().mockReturnValue({
|
parse: jest.fn().mockReturnValue({
|
||||||
resources: []
|
resources: [],
|
||||||
})
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const ListResourceTemplatesResultSchema = {
|
const ListResourceTemplatesResultSchema = {
|
||||||
parse: jest.fn().mockReturnValue({
|
parse: jest.fn().mockReturnValue({
|
||||||
resourceTemplates: []
|
resourceTemplates: [],
|
||||||
})
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const ReadResourceResultSchema = {
|
const ReadResourceResultSchema = {
|
||||||
parse: jest.fn().mockReturnValue({
|
parse: jest.fn().mockReturnValue({
|
||||||
contents: []
|
contents: [],
|
||||||
})
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const ErrorCode = {
|
const ErrorCode = {
|
||||||
InvalidRequest: 'InvalidRequest',
|
InvalidRequest: "InvalidRequest",
|
||||||
MethodNotFound: 'MethodNotFound',
|
MethodNotFound: "MethodNotFound",
|
||||||
InvalidParams: 'InvalidParams',
|
InvalidParams: "InvalidParams",
|
||||||
InternalError: 'InternalError'
|
InternalError: "InternalError",
|
||||||
}
|
}
|
||||||
|
|
||||||
class McpError extends Error {
|
class McpError extends Error {
|
||||||
constructor(code, message) {
|
constructor(code, message) {
|
||||||
super(message)
|
super(message)
|
||||||
this.code = code
|
this.code = code
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
CallToolResultSchema,
|
CallToolResultSchema,
|
||||||
ListToolsResultSchema,
|
ListToolsResultSchema,
|
||||||
ListResourcesResultSchema,
|
ListResourcesResultSchema,
|
||||||
ListResourceTemplatesResultSchema,
|
ListResourceTemplatesResultSchema,
|
||||||
ReadResourceResultSchema,
|
ReadResourceResultSchema,
|
||||||
ErrorCode,
|
ErrorCode,
|
||||||
McpError
|
McpError,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
export class McpHub {
|
export class McpHub {
|
||||||
connections = []
|
connections = []
|
||||||
isConnecting = false
|
isConnecting = false
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.toggleToolAlwaysAllow = jest.fn()
|
this.toggleToolAlwaysAllow = jest.fn()
|
||||||
this.callTool = jest.fn()
|
this.callTool = jest.fn()
|
||||||
}
|
}
|
||||||
|
|
||||||
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
|
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
|
|
||||||
async callTool(serverName: string, toolName: string, toolArguments?: Record<string, unknown>): Promise<any> {
|
async callTool(serverName: string, toolName: string, toolArguments?: Record<string, unknown>): Promise<any> {
|
||||||
return Promise.resolve({ result: 'success' })
|
return Promise.resolve({ result: "success" })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
// Mock default shell based on platform
|
// Mock default shell based on platform
|
||||||
const os = require('os');
|
const os = require("os")
|
||||||
|
|
||||||
let defaultShell;
|
let defaultShell
|
||||||
if (os.platform() === 'win32') {
|
if (os.platform() === "win32") {
|
||||||
defaultShell = 'cmd.exe';
|
defaultShell = "cmd.exe"
|
||||||
} else {
|
} else {
|
||||||
defaultShell = '/bin/bash';
|
defaultShell = "/bin/bash"
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = defaultShell;
|
module.exports = defaultShell
|
||||||
module.exports.default = defaultShell;
|
module.exports.default = defaultShell
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
function delay(ms) {
|
function delay(ms) {
|
||||||
return new Promise(resolve => setTimeout(resolve, ms));
|
return new Promise((resolve) => setTimeout(resolve, ms))
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = delay;
|
module.exports = delay
|
||||||
module.exports.default = delay;
|
module.exports.default = delay
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
function globby(patterns, options) {
|
function globby(patterns, options) {
|
||||||
return Promise.resolve([]);
|
return Promise.resolve([])
|
||||||
}
|
}
|
||||||
|
|
||||||
globby.sync = function(patterns, options) {
|
globby.sync = function (patterns, options) {
|
||||||
return [];
|
return []
|
||||||
};
|
}
|
||||||
|
|
||||||
module.exports = globby;
|
module.exports = globby
|
||||||
module.exports.default = globby;
|
module.exports.default = globby
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
function osName() {
|
function osName() {
|
||||||
return 'macOS';
|
return "macOS"
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = osName;
|
module.exports = osName
|
||||||
module.exports.default = osName;
|
module.exports.default = osName
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
function pWaitFor(condition, options = {}) {
|
function pWaitFor(condition, options = {}) {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
const interval = setInterval(() => {
|
const interval = setInterval(() => {
|
||||||
if (condition()) {
|
if (condition()) {
|
||||||
clearInterval(interval);
|
clearInterval(interval)
|
||||||
resolve();
|
resolve()
|
||||||
}
|
}
|
||||||
}, options.interval || 20);
|
}, options.interval || 20)
|
||||||
|
|
||||||
if (options.timeout) {
|
if (options.timeout) {
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
clearInterval(interval);
|
clearInterval(interval)
|
||||||
reject(new Error('Timed out'));
|
reject(new Error("Timed out"))
|
||||||
}, options.timeout);
|
}, options.timeout)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = pWaitFor;
|
module.exports = pWaitFor
|
||||||
module.exports.default = pWaitFor;
|
module.exports.default = pWaitFor
|
||||||
|
|||||||
@@ -1,25 +1,25 @@
|
|||||||
function serializeError(error) {
|
function serializeError(error) {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
return {
|
return {
|
||||||
name: error.name,
|
name: error.name,
|
||||||
message: error.message,
|
message: error.message,
|
||||||
stack: error.stack
|
stack: error.stack,
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
return error;
|
return error
|
||||||
}
|
}
|
||||||
|
|
||||||
function deserializeError(errorData) {
|
function deserializeError(errorData) {
|
||||||
if (errorData && typeof errorData === 'object') {
|
if (errorData && typeof errorData === "object") {
|
||||||
const error = new Error(errorData.message);
|
const error = new Error(errorData.message)
|
||||||
error.name = errorData.name;
|
error.name = errorData.name
|
||||||
error.stack = errorData.stack;
|
error.stack = errorData.stack
|
||||||
return error;
|
return error
|
||||||
}
|
}
|
||||||
return errorData;
|
return errorData
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
serializeError,
|
serializeError,
|
||||||
deserializeError
|
deserializeError,
|
||||||
};
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
function stripAnsi(string) {
|
function stripAnsi(string) {
|
||||||
// Simple mock that just returns the input string
|
// Simple mock that just returns the input string
|
||||||
return string;
|
return string
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = stripAnsi;
|
module.exports = stripAnsi
|
||||||
module.exports.default = stripAnsi;
|
module.exports.default = stripAnsi
|
||||||
|
|||||||
@@ -1,57 +1,57 @@
|
|||||||
const vscode = {
|
const vscode = {
|
||||||
window: {
|
window: {
|
||||||
showInformationMessage: jest.fn(),
|
showInformationMessage: jest.fn(),
|
||||||
showErrorMessage: jest.fn(),
|
showErrorMessage: jest.fn(),
|
||||||
createTextEditorDecorationType: jest.fn().mockReturnValue({
|
createTextEditorDecorationType: jest.fn().mockReturnValue({
|
||||||
dispose: jest.fn()
|
dispose: jest.fn(),
|
||||||
})
|
}),
|
||||||
},
|
},
|
||||||
workspace: {
|
workspace: {
|
||||||
onDidSaveTextDocument: jest.fn()
|
onDidSaveTextDocument: jest.fn(),
|
||||||
},
|
},
|
||||||
Disposable: class {
|
Disposable: class {
|
||||||
dispose() {}
|
dispose() {}
|
||||||
},
|
},
|
||||||
Uri: {
|
Uri: {
|
||||||
file: (path) => ({
|
file: (path) => ({
|
||||||
fsPath: path,
|
fsPath: path,
|
||||||
scheme: 'file',
|
scheme: "file",
|
||||||
authority: '',
|
authority: "",
|
||||||
path: path,
|
path: path,
|
||||||
query: '',
|
query: "",
|
||||||
fragment: '',
|
fragment: "",
|
||||||
with: jest.fn(),
|
with: jest.fn(),
|
||||||
toJSON: jest.fn()
|
toJSON: jest.fn(),
|
||||||
})
|
}),
|
||||||
},
|
},
|
||||||
EventEmitter: class {
|
EventEmitter: class {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.event = jest.fn();
|
this.event = jest.fn()
|
||||||
this.fire = jest.fn();
|
this.fire = jest.fn()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ConfigurationTarget: {
|
ConfigurationTarget: {
|
||||||
Global: 1,
|
Global: 1,
|
||||||
Workspace: 2,
|
Workspace: 2,
|
||||||
WorkspaceFolder: 3
|
WorkspaceFolder: 3,
|
||||||
},
|
},
|
||||||
Position: class {
|
Position: class {
|
||||||
constructor(line, character) {
|
constructor(line, character) {
|
||||||
this.line = line;
|
this.line = line
|
||||||
this.character = character;
|
this.character = character
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Range: class {
|
Range: class {
|
||||||
constructor(startLine, startCharacter, endLine, endCharacter) {
|
constructor(startLine, startCharacter, endLine, endCharacter) {
|
||||||
this.start = new vscode.Position(startLine, startCharacter);
|
this.start = new vscode.Position(startLine, startCharacter)
|
||||||
this.end = new vscode.Position(endLine, endCharacter);
|
this.end = new vscode.Position(endLine, endCharacter)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ThemeColor: class {
|
ThemeColor: class {
|
||||||
constructor(id) {
|
constructor(id) {
|
||||||
this.id = id;
|
this.id = id
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
|
|
||||||
module.exports = vscode;
|
module.exports = vscode
|
||||||
|
|||||||
@@ -1,239 +1,238 @@
|
|||||||
import { AnthropicHandler } from '../anthropic';
|
import { AnthropicHandler } from "../anthropic"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import { ApiStream } from '../../transform/stream';
|
import { ApiStream } from "../../transform/stream"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock Anthropic client
|
// Mock Anthropic client
|
||||||
const mockBetaCreate = jest.fn();
|
const mockBetaCreate = jest.fn()
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
jest.mock('@anthropic-ai/sdk', () => {
|
jest.mock("@anthropic-ai/sdk", () => {
|
||||||
return {
|
return {
|
||||||
Anthropic: jest.fn().mockImplementation(() => ({
|
Anthropic: jest.fn().mockImplementation(() => ({
|
||||||
beta: {
|
beta: {
|
||||||
promptCaching: {
|
promptCaching: {
|
||||||
messages: {
|
messages: {
|
||||||
create: mockBetaCreate.mockImplementation(async () => ({
|
create: mockBetaCreate.mockImplementation(async () => ({
|
||||||
async *[Symbol.asyncIterator]() {
|
async *[Symbol.asyncIterator]() {
|
||||||
yield {
|
yield {
|
||||||
type: 'message_start',
|
type: "message_start",
|
||||||
message: {
|
message: {
|
||||||
usage: {
|
usage: {
|
||||||
input_tokens: 100,
|
input_tokens: 100,
|
||||||
output_tokens: 50,
|
output_tokens: 50,
|
||||||
cache_creation_input_tokens: 20,
|
cache_creation_input_tokens: 20,
|
||||||
cache_read_input_tokens: 10
|
cache_read_input_tokens: 10,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
yield {
|
yield {
|
||||||
type: 'content_block_start',
|
type: "content_block_start",
|
||||||
index: 0,
|
index: 0,
|
||||||
content_block: {
|
content_block: {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Hello'
|
text: "Hello",
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
yield {
|
yield {
|
||||||
type: 'content_block_delta',
|
type: "content_block_delta",
|
||||||
delta: {
|
delta: {
|
||||||
type: 'text_delta',
|
type: "text_delta",
|
||||||
text: ' world'
|
text: " world",
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
}
|
},
|
||||||
}))
|
})),
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
messages: {
|
messages: {
|
||||||
create: mockCreate.mockImplementation(async (options) => {
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
content: [
|
content: [{ type: "text", text: "Test response" }],
|
||||||
{ type: 'text', text: 'Test response' }
|
role: "assistant",
|
||||||
],
|
model: options.model,
|
||||||
role: 'assistant',
|
usage: {
|
||||||
model: options.model,
|
input_tokens: 10,
|
||||||
usage: {
|
output_tokens: 5,
|
||||||
input_tokens: 10,
|
},
|
||||||
output_tokens: 5
|
}
|
||||||
}
|
}
|
||||||
}
|
return {
|
||||||
}
|
async *[Symbol.asyncIterator]() {
|
||||||
return {
|
yield {
|
||||||
async *[Symbol.asyncIterator]() {
|
type: "message_start",
|
||||||
yield {
|
message: {
|
||||||
type: 'message_start',
|
usage: {
|
||||||
message: {
|
input_tokens: 10,
|
||||||
usage: {
|
output_tokens: 5,
|
||||||
input_tokens: 10,
|
},
|
||||||
output_tokens: 5
|
},
|
||||||
}
|
}
|
||||||
}
|
yield {
|
||||||
}
|
type: "content_block_start",
|
||||||
yield {
|
content_block: {
|
||||||
type: 'content_block_start',
|
type: "text",
|
||||||
content_block: {
|
text: "Test response",
|
||||||
type: 'text',
|
},
|
||||||
text: 'Test response'
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}),
|
||||||
}
|
},
|
||||||
})
|
})),
|
||||||
}
|
}
|
||||||
}))
|
})
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('AnthropicHandler', () => {
|
describe("AnthropicHandler", () => {
|
||||||
let handler: AnthropicHandler;
|
let handler: AnthropicHandler
|
||||||
let mockOptions: ApiHandlerOptions;
|
let mockOptions: ApiHandlerOptions
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockOptions = {
|
mockOptions = {
|
||||||
apiKey: 'test-api-key',
|
apiKey: "test-api-key",
|
||||||
apiModelId: 'claude-3-5-sonnet-20241022'
|
apiModelId: "claude-3-5-sonnet-20241022",
|
||||||
};
|
}
|
||||||
handler = new AnthropicHandler(mockOptions);
|
handler = new AnthropicHandler(mockOptions)
|
||||||
mockBetaCreate.mockClear();
|
mockBetaCreate.mockClear()
|
||||||
mockCreate.mockClear();
|
mockCreate.mockClear()
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe("constructor", () => {
|
||||||
it('should initialize with provided options', () => {
|
it("should initialize with provided options", () => {
|
||||||
expect(handler).toBeInstanceOf(AnthropicHandler);
|
expect(handler).toBeInstanceOf(AnthropicHandler)
|
||||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should initialize with undefined API key', () => {
|
it("should initialize with undefined API key", () => {
|
||||||
// The SDK will handle API key validation, so we just verify it initializes
|
// The SDK will handle API key validation, so we just verify it initializes
|
||||||
const handlerWithoutKey = new AnthropicHandler({
|
const handlerWithoutKey = new AnthropicHandler({
|
||||||
...mockOptions,
|
...mockOptions,
|
||||||
apiKey: undefined
|
apiKey: undefined,
|
||||||
});
|
})
|
||||||
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler);
|
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should use custom base URL if provided', () => {
|
it("should use custom base URL if provided", () => {
|
||||||
const customBaseUrl = 'https://custom.anthropic.com';
|
const customBaseUrl = "https://custom.anthropic.com"
|
||||||
const handlerWithCustomUrl = new AnthropicHandler({
|
const handlerWithCustomUrl = new AnthropicHandler({
|
||||||
...mockOptions,
|
...mockOptions,
|
||||||
anthropicBaseUrl: customBaseUrl
|
anthropicBaseUrl: customBaseUrl,
|
||||||
});
|
})
|
||||||
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler);
|
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe("createMessage", () => {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
const systemPrompt = "You are a helpful assistant."
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: [{
|
content: [
|
||||||
type: 'text' as const,
|
{
|
||||||
text: 'Hello!'
|
type: "text" as const,
|
||||||
}]
|
text: "Hello!",
|
||||||
}
|
},
|
||||||
];
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
it('should handle prompt caching for supported models', async () => {
|
it("should handle prompt caching for supported models", async () => {
|
||||||
const stream = handler.createMessage(systemPrompt, [
|
const stream = handler.createMessage(systemPrompt, [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: [{ type: 'text' as const, text: 'First message' }]
|
content: [{ type: "text" as const, text: "First message" }],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: [{ type: 'text' as const, text: 'Response' }]
|
content: [{ type: "text" as const, text: "Response" }],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: [{ type: 'text' as const, text: 'Second message' }]
|
content: [{ type: "text" as const, text: "Second message" }],
|
||||||
}
|
},
|
||||||
]);
|
])
|
||||||
|
|
||||||
const chunks: any[] = [];
|
const chunks: any[] = []
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify usage information
|
// Verify usage information
|
||||||
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
|
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
|
||||||
expect(usageChunk).toBeDefined();
|
expect(usageChunk).toBeDefined()
|
||||||
expect(usageChunk?.inputTokens).toBe(100);
|
expect(usageChunk?.inputTokens).toBe(100)
|
||||||
expect(usageChunk?.outputTokens).toBe(50);
|
expect(usageChunk?.outputTokens).toBe(50)
|
||||||
expect(usageChunk?.cacheWriteTokens).toBe(20);
|
expect(usageChunk?.cacheWriteTokens).toBe(20)
|
||||||
expect(usageChunk?.cacheReadTokens).toBe(10);
|
expect(usageChunk?.cacheReadTokens).toBe(10)
|
||||||
|
|
||||||
// Verify text content
|
// Verify text content
|
||||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||||
expect(textChunks).toHaveLength(2);
|
expect(textChunks).toHaveLength(2)
|
||||||
expect(textChunks[0].text).toBe('Hello');
|
expect(textChunks[0].text).toBe("Hello")
|
||||||
expect(textChunks[1].text).toBe(' world');
|
expect(textChunks[1].text).toBe(" world")
|
||||||
|
|
||||||
// Verify beta API was used
|
// Verify beta API was used
|
||||||
expect(mockBetaCreate).toHaveBeenCalled();
|
expect(mockBetaCreate).toHaveBeenCalled()
|
||||||
expect(mockCreate).not.toHaveBeenCalled();
|
expect(mockCreate).not.toHaveBeenCalled()
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
describe("completePrompt", () => {
|
||||||
it('should complete prompt successfully', async () => {
|
it("should complete prompt successfully", async () => {
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe('Test response');
|
expect(result).toBe("Test response")
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
model: mockOptions.apiModelId,
|
model: mockOptions.apiModelId,
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
max_tokens: 8192,
|
max_tokens: 8192,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: false
|
stream: false,
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
it("should handle API errors", async () => {
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error")
|
||||||
.rejects.toThrow('Anthropic completion error: API Error');
|
})
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle non-text content', async () => {
|
it("should handle non-text content", async () => {
|
||||||
mockCreate.mockImplementationOnce(async () => ({
|
mockCreate.mockImplementationOnce(async () => ({
|
||||||
content: [{ type: 'image' }]
|
content: [{ type: "image" }],
|
||||||
}));
|
}))
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe('');
|
expect(result).toBe("")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
it("should handle empty response", async () => {
|
||||||
mockCreate.mockImplementationOnce(async () => ({
|
mockCreate.mockImplementationOnce(async () => ({
|
||||||
content: [{ type: 'text', text: '' }]
|
content: [{ type: "text", text: "" }],
|
||||||
}));
|
}))
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe('');
|
expect(result).toBe("")
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('getModel', () => {
|
describe("getModel", () => {
|
||||||
it('should return default model if no model ID is provided', () => {
|
it("should return default model if no model ID is provided", () => {
|
||||||
const handlerWithoutModel = new AnthropicHandler({
|
const handlerWithoutModel = new AnthropicHandler({
|
||||||
...mockOptions,
|
...mockOptions,
|
||||||
apiModelId: undefined
|
apiModelId: undefined,
|
||||||
});
|
})
|
||||||
const model = handlerWithoutModel.getModel();
|
const model = handlerWithoutModel.getModel()
|
||||||
expect(model.id).toBeDefined();
|
expect(model.id).toBeDefined()
|
||||||
expect(model.info).toBeDefined();
|
expect(model.info).toBeDefined()
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should return specified model if valid model ID is provided', () => {
|
it("should return specified model if valid model ID is provided", () => {
|
||||||
const model = handler.getModel();
|
const model = handler.getModel()
|
||||||
expect(model.id).toBe(mockOptions.apiModelId);
|
expect(model.id).toBe(mockOptions.apiModelId)
|
||||||
expect(model.info).toBeDefined();
|
expect(model.info).toBeDefined()
|
||||||
expect(model.info.maxTokens).toBe(8192);
|
expect(model.info.maxTokens).toBe(8192)
|
||||||
expect(model.info.contextWindow).toBe(200_000);
|
expect(model.info.contextWindow).toBe(200_000)
|
||||||
expect(model.info.supportsImages).toBe(true);
|
expect(model.info.supportsImages).toBe(true)
|
||||||
expect(model.info.supportsPromptCache).toBe(true);
|
expect(model.info.supportsPromptCache).toBe(true)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|||||||
@@ -1,246 +1,259 @@
|
|||||||
import { AwsBedrockHandler } from '../bedrock';
|
import { AwsBedrockHandler } from "../bedrock"
|
||||||
import { MessageContent } from '../../../shared/api';
|
import { MessageContent } from "../../../shared/api"
|
||||||
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
|
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
describe('AwsBedrockHandler', () => {
|
describe("AwsBedrockHandler", () => {
|
||||||
let handler: AwsBedrockHandler;
|
let handler: AwsBedrockHandler
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
handler = new AwsBedrockHandler({
|
handler = new AwsBedrockHandler({
|
||||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
awsAccessKey: 'test-access-key',
|
awsAccessKey: "test-access-key",
|
||||||
awsSecretKey: 'test-secret-key',
|
awsSecretKey: "test-secret-key",
|
||||||
awsRegion: 'us-east-1'
|
awsRegion: "us-east-1",
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe("constructor", () => {
|
||||||
it('should initialize with provided config', () => {
|
it("should initialize with provided config", () => {
|
||||||
expect(handler['options'].awsAccessKey).toBe('test-access-key');
|
expect(handler["options"].awsAccessKey).toBe("test-access-key")
|
||||||
expect(handler['options'].awsSecretKey).toBe('test-secret-key');
|
expect(handler["options"].awsSecretKey).toBe("test-secret-key")
|
||||||
expect(handler['options'].awsRegion).toBe('us-east-1');
|
expect(handler["options"].awsRegion).toBe("us-east-1")
|
||||||
expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
|
expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should initialize with missing AWS credentials', () => {
|
it("should initialize with missing AWS credentials", () => {
|
||||||
const handlerWithoutCreds = new AwsBedrockHandler({
|
const handlerWithoutCreds = new AwsBedrockHandler({
|
||||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
awsRegion: 'us-east-1'
|
awsRegion: "us-east-1",
|
||||||
});
|
})
|
||||||
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler);
|
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe("createMessage", () => {
|
||||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: 'Hello'
|
content: "Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: 'Hi there!'
|
content: "Hi there!",
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
const systemPrompt = "You are a helpful assistant"
|
||||||
|
|
||||||
it('should handle text messages correctly', async () => {
|
it("should handle text messages correctly", async () => {
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
messages: [{
|
messages: [
|
||||||
role: 'assistant',
|
{
|
||||||
content: [{ type: 'text', text: 'Hello! How can I help you?' }]
|
role: "assistant",
|
||||||
}],
|
content: [{ type: "text", text: "Hello! How can I help you?" }],
|
||||||
usage: {
|
},
|
||||||
input_tokens: 10,
|
],
|
||||||
output_tokens: 5
|
usage: {
|
||||||
}
|
input_tokens: 10,
|
||||||
};
|
output_tokens: 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Mock AWS SDK invoke
|
// Mock AWS SDK invoke
|
||||||
const mockStream = {
|
const mockStream = {
|
||||||
[Symbol.asyncIterator]: async function* () {
|
[Symbol.asyncIterator]: async function* () {
|
||||||
yield {
|
yield {
|
||||||
metadata: {
|
metadata: {
|
||||||
usage: {
|
usage: {
|
||||||
inputTokens: 10,
|
inputTokens: 10,
|
||||||
outputTokens: 5
|
outputTokens: 5,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
|
|
||||||
const mockInvoke = jest.fn().mockResolvedValue({
|
const mockInvoke = jest.fn().mockResolvedValue({
|
||||||
stream: mockStream
|
stream: mockStream,
|
||||||
});
|
})
|
||||||
|
|
||||||
handler['client'] = {
|
handler["client"] = {
|
||||||
send: mockInvoke
|
send: mockInvoke,
|
||||||
} as unknown as BedrockRuntimeClient;
|
} as unknown as BedrockRuntimeClient
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
const chunks = [];
|
const chunks = []
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
for await (const chunk of stream) {
|
||||||
expect(chunks[0]).toEqual({
|
chunks.push(chunk)
|
||||||
type: 'usage',
|
}
|
||||||
inputTokens: 10,
|
|
||||||
outputTokens: 5
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
input: expect.objectContaining({
|
expect(chunks[0]).toEqual({
|
||||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0'
|
type: "usage",
|
||||||
})
|
inputTokens: 10,
|
||||||
}));
|
outputTokens: 5,
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
expect(mockInvoke).toHaveBeenCalledWith(
|
||||||
// Mock AWS SDK invoke with error
|
expect.objectContaining({
|
||||||
const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error'));
|
input: expect.objectContaining({
|
||||||
|
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
handler['client'] = {
|
it("should handle API errors", async () => {
|
||||||
send: mockInvoke
|
// Mock AWS SDK invoke with error
|
||||||
} as unknown as BedrockRuntimeClient;
|
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
handler["client"] = {
|
||||||
|
send: mockInvoke,
|
||||||
|
} as unknown as BedrockRuntimeClient
|
||||||
|
|
||||||
await expect(async () => {
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
for await (const chunk of stream) {
|
|
||||||
// Should throw before yielding any chunks
|
|
||||||
}
|
|
||||||
}).rejects.toThrow('AWS Bedrock error');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
await expect(async () => {
|
||||||
it('should complete prompt successfully', async () => {
|
for await (const chunk of stream) {
|
||||||
const mockResponse = {
|
// Should throw before yielding any chunks
|
||||||
output: new TextEncoder().encode(JSON.stringify({
|
}
|
||||||
content: 'Test response'
|
}).rejects.toThrow("AWS Bedrock error")
|
||||||
}))
|
})
|
||||||
};
|
})
|
||||||
|
|
||||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
describe("completePrompt", () => {
|
||||||
handler['client'] = {
|
it("should complete prompt successfully", async () => {
|
||||||
send: mockSend
|
const mockResponse = {
|
||||||
} as unknown as BedrockRuntimeClient;
|
output: new TextEncoder().encode(
|
||||||
|
JSON.stringify({
|
||||||
|
content: "Test response",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||||
expect(result).toBe('Test response');
|
handler["client"] = {
|
||||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
send: mockSend,
|
||||||
input: expect.objectContaining({
|
} as unknown as BedrockRuntimeClient
|
||||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
|
||||||
messages: expect.arrayContaining([
|
|
||||||
expect.objectContaining({
|
|
||||||
role: 'user',
|
|
||||||
content: [{ text: 'Test prompt' }]
|
|
||||||
})
|
|
||||||
]),
|
|
||||||
inferenceConfig: expect.objectContaining({
|
|
||||||
maxTokens: 5000,
|
|
||||||
temperature: 0.3,
|
|
||||||
topP: 0.1
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}));
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
const mockError = new Error('AWS Bedrock error');
|
expect(result).toBe("Test response")
|
||||||
const mockSend = jest.fn().mockRejectedValue(mockError);
|
expect(mockSend).toHaveBeenCalledWith(
|
||||||
handler['client'] = {
|
expect.objectContaining({
|
||||||
send: mockSend
|
input: expect.objectContaining({
|
||||||
} as unknown as BedrockRuntimeClient;
|
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
messages: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
role: "user",
|
||||||
|
content: [{ text: "Test prompt" }],
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
inferenceConfig: expect.objectContaining({
|
||||||
|
maxTokens: 5000,
|
||||||
|
temperature: 0.3,
|
||||||
|
topP: 0.1,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
it("should handle API errors", async () => {
|
||||||
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
|
const mockError = new Error("AWS Bedrock error")
|
||||||
});
|
const mockSend = jest.fn().mockRejectedValue(mockError)
|
||||||
|
handler["client"] = {
|
||||||
|
send: mockSend,
|
||||||
|
} as unknown as BedrockRuntimeClient
|
||||||
|
|
||||||
it('should handle invalid response format', async () => {
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||||
const mockResponse = {
|
"Bedrock completion error: AWS Bedrock error",
|
||||||
output: new TextEncoder().encode('invalid json')
|
)
|
||||||
};
|
})
|
||||||
|
|
||||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
it("should handle invalid response format", async () => {
|
||||||
handler['client'] = {
|
const mockResponse = {
|
||||||
send: mockSend
|
output: new TextEncoder().encode("invalid json"),
|
||||||
} as unknown as BedrockRuntimeClient;
|
}
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||||
expect(result).toBe('');
|
handler["client"] = {
|
||||||
});
|
send: mockSend,
|
||||||
|
} as unknown as BedrockRuntimeClient
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
const mockResponse = {
|
expect(result).toBe("")
|
||||||
output: new TextEncoder().encode(JSON.stringify({}))
|
})
|
||||||
};
|
|
||||||
|
|
||||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
it("should handle empty response", async () => {
|
||||||
handler['client'] = {
|
const mockResponse = {
|
||||||
send: mockSend
|
output: new TextEncoder().encode(JSON.stringify({})),
|
||||||
} as unknown as BedrockRuntimeClient;
|
}
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||||
expect(result).toBe('');
|
handler["client"] = {
|
||||||
});
|
send: mockSend,
|
||||||
|
} as unknown as BedrockRuntimeClient
|
||||||
|
|
||||||
it('should handle cross-region inference', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
handler = new AwsBedrockHandler({
|
expect(result).toBe("")
|
||||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
})
|
||||||
awsAccessKey: 'test-access-key',
|
|
||||||
awsSecretKey: 'test-secret-key',
|
|
||||||
awsRegion: 'us-east-1',
|
|
||||||
awsUseCrossRegionInference: true
|
|
||||||
});
|
|
||||||
|
|
||||||
const mockResponse = {
|
it("should handle cross-region inference", async () => {
|
||||||
output: new TextEncoder().encode(JSON.stringify({
|
handler = new AwsBedrockHandler({
|
||||||
content: 'Test response'
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
}))
|
awsAccessKey: "test-access-key",
|
||||||
};
|
awsSecretKey: "test-secret-key",
|
||||||
|
awsRegion: "us-east-1",
|
||||||
|
awsUseCrossRegionInference: true,
|
||||||
|
})
|
||||||
|
|
||||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
const mockResponse = {
|
||||||
handler['client'] = {
|
output: new TextEncoder().encode(
|
||||||
send: mockSend
|
JSON.stringify({
|
||||||
} as unknown as BedrockRuntimeClient;
|
content: "Test response",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||||
expect(result).toBe('Test response');
|
handler["client"] = {
|
||||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
send: mockSend,
|
||||||
input: expect.objectContaining({
|
} as unknown as BedrockRuntimeClient
|
||||||
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
|
|
||||||
})
|
|
||||||
}));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getModel', () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
it('should return correct model info in test environment', () => {
|
expect(result).toBe("Test response")
|
||||||
const modelInfo = handler.getModel();
|
expect(mockSend).toHaveBeenCalledWith(
|
||||||
expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
|
expect.objectContaining({
|
||||||
expect(modelInfo.info).toBeDefined();
|
input: expect.objectContaining({
|
||||||
expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value
|
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value
|
}),
|
||||||
});
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should return test model info for invalid model in test environment', () => {
|
describe("getModel", () => {
|
||||||
const invalidHandler = new AwsBedrockHandler({
|
it("should return correct model info in test environment", () => {
|
||||||
apiModelId: 'invalid-model',
|
const modelInfo = handler.getModel()
|
||||||
awsAccessKey: 'test-access-key',
|
expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||||
awsSecretKey: 'test-secret-key',
|
expect(modelInfo.info).toBeDefined()
|
||||||
awsRegion: 'us-east-1'
|
expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
|
||||||
});
|
expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
|
||||||
const modelInfo = invalidHandler.getModel();
|
})
|
||||||
expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed
|
|
||||||
expect(modelInfo.info.maxTokens).toBe(5000);
|
it("should return test model info for invalid model in test environment", () => {
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
const invalidHandler = new AwsBedrockHandler({
|
||||||
});
|
apiModelId: "invalid-model",
|
||||||
});
|
awsAccessKey: "test-access-key",
|
||||||
});
|
awsSecretKey: "test-secret-key",
|
||||||
|
awsRegion: "us-east-1",
|
||||||
|
})
|
||||||
|
const modelInfo = invalidHandler.getModel()
|
||||||
|
expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
|
||||||
|
expect(modelInfo.info.maxTokens).toBe(5000)
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,203 +1,217 @@
|
|||||||
import { DeepSeekHandler } from '../deepseek';
|
import { DeepSeekHandler } from "../deepseek"
|
||||||
import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api';
|
import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock OpenAI client
|
// Mock OpenAI client
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
jest.mock('openai', () => {
|
jest.mock("openai", () => {
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: mockCreate.mockImplementation(async (options) => {
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
choices: [{
|
choices: [
|
||||||
message: { role: 'assistant', content: 'Test response', refusal: null },
|
{
|
||||||
finish_reason: 'stop',
|
message: { role: "assistant", content: "Test response", refusal: null },
|
||||||
index: 0
|
finish_reason: "stop",
|
||||||
}],
|
index: 0,
|
||||||
usage: {
|
},
|
||||||
prompt_tokens: 10,
|
],
|
||||||
completion_tokens: 5,
|
usage: {
|
||||||
total_tokens: 15
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
};
|
total_tokens: 15,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
// Return async iterator for streaming
|
}
|
||||||
return {
|
|
||||||
[Symbol.asyncIterator]: async function* () {
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Test response' },
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: null
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: {},
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 10,
|
|
||||||
completion_tokens: 5,
|
|
||||||
total_tokens: 15
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('DeepSeekHandler', () => {
|
// Return async iterator for streaming
|
||||||
let handler: DeepSeekHandler;
|
return {
|
||||||
let mockOptions: ApiHandlerOptions;
|
[Symbol.asyncIterator]: async function* () {
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: "Test response" },
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: null,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
describe("DeepSeekHandler", () => {
|
||||||
mockOptions = {
|
let handler: DeepSeekHandler
|
||||||
deepSeekApiKey: 'test-api-key',
|
let mockOptions: ApiHandlerOptions
|
||||||
deepSeekModelId: 'deepseek-chat',
|
|
||||||
deepSeekBaseUrl: 'https://api.deepseek.com/v1'
|
|
||||||
};
|
|
||||||
handler = new DeepSeekHandler(mockOptions);
|
|
||||||
mockCreate.mockClear();
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('constructor', () => {
|
beforeEach(() => {
|
||||||
it('should initialize with provided options', () => {
|
mockOptions = {
|
||||||
expect(handler).toBeInstanceOf(DeepSeekHandler);
|
deepSeekApiKey: "test-api-key",
|
||||||
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId);
|
deepSeekModelId: "deepseek-chat",
|
||||||
});
|
deepSeekBaseUrl: "https://api.deepseek.com/v1",
|
||||||
|
}
|
||||||
|
handler = new DeepSeekHandler(mockOptions)
|
||||||
|
mockCreate.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
it('should throw error if API key is missing', () => {
|
describe("constructor", () => {
|
||||||
expect(() => {
|
it("should initialize with provided options", () => {
|
||||||
new DeepSeekHandler({
|
expect(handler).toBeInstanceOf(DeepSeekHandler)
|
||||||
...mockOptions,
|
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId)
|
||||||
deepSeekApiKey: undefined
|
})
|
||||||
});
|
|
||||||
}).toThrow('DeepSeek API key is required');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should use default model ID if not provided', () => {
|
it("should throw error if API key is missing", () => {
|
||||||
const handlerWithoutModel = new DeepSeekHandler({
|
expect(() => {
|
||||||
...mockOptions,
|
new DeepSeekHandler({
|
||||||
deepSeekModelId: undefined
|
...mockOptions,
|
||||||
});
|
deepSeekApiKey: undefined,
|
||||||
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId);
|
})
|
||||||
});
|
}).toThrow("DeepSeek API key is required")
|
||||||
|
})
|
||||||
|
|
||||||
it('should use default base URL if not provided', () => {
|
it("should use default model ID if not provided", () => {
|
||||||
const handlerWithoutBaseUrl = new DeepSeekHandler({
|
const handlerWithoutModel = new DeepSeekHandler({
|
||||||
...mockOptions,
|
...mockOptions,
|
||||||
deepSeekBaseUrl: undefined
|
deepSeekModelId: undefined,
|
||||||
});
|
})
|
||||||
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler);
|
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId)
|
||||||
// The base URL is passed to OpenAI client internally
|
})
|
||||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
|
||||||
baseURL: 'https://api.deepseek.com/v1'
|
|
||||||
}));
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should use custom base URL if provided', () => {
|
it("should use default base URL if not provided", () => {
|
||||||
const customBaseUrl = 'https://custom.deepseek.com/v1';
|
const handlerWithoutBaseUrl = new DeepSeekHandler({
|
||||||
const handlerWithCustomUrl = new DeepSeekHandler({
|
...mockOptions,
|
||||||
...mockOptions,
|
deepSeekBaseUrl: undefined,
|
||||||
deepSeekBaseUrl: customBaseUrl
|
})
|
||||||
});
|
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler)
|
||||||
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler);
|
// The base URL is passed to OpenAI client internally
|
||||||
// The custom base URL is passed to OpenAI client
|
expect(OpenAI).toHaveBeenCalledWith(
|
||||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
expect.objectContaining({
|
||||||
baseURL: customBaseUrl
|
baseURL: "https://api.deepseek.com/v1",
|
||||||
}));
|
}),
|
||||||
});
|
)
|
||||||
|
})
|
||||||
|
|
||||||
it('should set includeMaxTokens to true', () => {
|
it("should use custom base URL if provided", () => {
|
||||||
// Create a new handler and verify OpenAI client was called with includeMaxTokens
|
const customBaseUrl = "https://custom.deepseek.com/v1"
|
||||||
new DeepSeekHandler(mockOptions);
|
const handlerWithCustomUrl = new DeepSeekHandler({
|
||||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
...mockOptions,
|
||||||
apiKey: mockOptions.deepSeekApiKey
|
deepSeekBaseUrl: customBaseUrl,
|
||||||
}));
|
})
|
||||||
});
|
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler)
|
||||||
});
|
// The custom base URL is passed to OpenAI client
|
||||||
|
expect(OpenAI).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
baseURL: customBaseUrl,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
describe('getModel', () => {
|
it("should set includeMaxTokens to true", () => {
|
||||||
it('should return model info for valid model ID', () => {
|
// Create a new handler and verify OpenAI client was called with includeMaxTokens
|
||||||
const model = handler.getModel();
|
new DeepSeekHandler(mockOptions)
|
||||||
expect(model.id).toBe(mockOptions.deepSeekModelId);
|
expect(OpenAI).toHaveBeenCalledWith(
|
||||||
expect(model.info).toBeDefined();
|
expect.objectContaining({
|
||||||
expect(model.info.maxTokens).toBe(8192);
|
apiKey: mockOptions.deepSeekApiKey,
|
||||||
expect(model.info.contextWindow).toBe(64_000);
|
}),
|
||||||
expect(model.info.supportsImages).toBe(false);
|
)
|
||||||
expect(model.info.supportsPromptCache).toBe(false);
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should return provided model ID with default model info if model does not exist', () => {
|
describe("getModel", () => {
|
||||||
const handlerWithInvalidModel = new DeepSeekHandler({
|
it("should return model info for valid model ID", () => {
|
||||||
...mockOptions,
|
const model = handler.getModel()
|
||||||
deepSeekModelId: 'invalid-model'
|
expect(model.id).toBe(mockOptions.deepSeekModelId)
|
||||||
});
|
expect(model.info).toBeDefined()
|
||||||
const model = handlerWithInvalidModel.getModel();
|
expect(model.info.maxTokens).toBe(8192)
|
||||||
expect(model.id).toBe('invalid-model'); // Returns provided ID
|
expect(model.info.contextWindow).toBe(64_000)
|
||||||
expect(model.info).toBeDefined();
|
expect(model.info.supportsImages).toBe(false)
|
||||||
expect(model.info).toBe(handler.getModel().info); // But uses default model info
|
expect(model.info.supportsPromptCache).toBe(false)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should return default model if no model ID is provided', () => {
|
it("should return provided model ID with default model info if model does not exist", () => {
|
||||||
const handlerWithoutModel = new DeepSeekHandler({
|
const handlerWithInvalidModel = new DeepSeekHandler({
|
||||||
...mockOptions,
|
...mockOptions,
|
||||||
deepSeekModelId: undefined
|
deepSeekModelId: "invalid-model",
|
||||||
});
|
})
|
||||||
const model = handlerWithoutModel.getModel();
|
const model = handlerWithInvalidModel.getModel()
|
||||||
expect(model.id).toBe(deepSeekDefaultModelId);
|
expect(model.id).toBe("invalid-model") // Returns provided ID
|
||||||
expect(model.info).toBeDefined();
|
expect(model.info).toBeDefined()
|
||||||
});
|
expect(model.info).toBe(handler.getModel().info) // But uses default model info
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createMessage', () => {
|
it("should return default model if no model ID is provided", () => {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
const handlerWithoutModel = new DeepSeekHandler({
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
...mockOptions,
|
||||||
{
|
deepSeekModelId: undefined,
|
||||||
role: 'user',
|
})
|
||||||
content: [{
|
const model = handlerWithoutModel.getModel()
|
||||||
type: 'text' as const,
|
expect(model.id).toBe(deepSeekDefaultModelId)
|
||||||
text: 'Hello!'
|
expect(model.info).toBeDefined()
|
||||||
}]
|
})
|
||||||
}
|
})
|
||||||
];
|
|
||||||
|
|
||||||
it('should handle streaming responses', async () => {
|
describe("createMessage", () => {
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const systemPrompt = "You are a helpful assistant."
|
||||||
const chunks: any[] = [];
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
for await (const chunk of stream) {
|
{
|
||||||
chunks.push(chunk);
|
role: "user",
|
||||||
}
|
content: [
|
||||||
|
{
|
||||||
|
type: "text" as const,
|
||||||
|
text: "Hello!",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
it("should handle streaming responses", async () => {
|
||||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
expect(textChunks).toHaveLength(1);
|
const chunks: any[] = []
|
||||||
expect(textChunks[0].text).toBe('Test response');
|
for await (const chunk of stream) {
|
||||||
});
|
chunks.push(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
it('should include usage information', async () => {
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||||
const chunks: any[] = [];
|
expect(textChunks).toHaveLength(1)
|
||||||
for await (const chunk of stream) {
|
expect(textChunks[0].text).toBe("Test response")
|
||||||
chunks.push(chunk);
|
})
|
||||||
}
|
|
||||||
|
|
||||||
const usageChunks = chunks.filter(chunk => chunk.type === 'usage');
|
it("should include usage information", async () => {
|
||||||
expect(usageChunks.length).toBeGreaterThan(0);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
expect(usageChunks[0].inputTokens).toBe(10);
|
const chunks: any[] = []
|
||||||
expect(usageChunks[0].outputTokens).toBe(5);
|
for await (const chunk of stream) {
|
||||||
});
|
chunks.push(chunk)
|
||||||
});
|
}
|
||||||
});
|
|
||||||
|
const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
|
||||||
|
expect(usageChunks.length).toBeGreaterThan(0)
|
||||||
|
expect(usageChunks[0].inputTokens).toBe(10)
|
||||||
|
expect(usageChunks[0].outputTokens).toBe(5)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,212 +1,210 @@
|
|||||||
import { GeminiHandler } from '../gemini';
|
import { GeminiHandler } from "../gemini"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { GoogleGenerativeAI } from '@google/generative-ai';
|
import { GoogleGenerativeAI } from "@google/generative-ai"
|
||||||
|
|
||||||
// Mock the Google Generative AI SDK
|
// Mock the Google Generative AI SDK
|
||||||
jest.mock('@google/generative-ai', () => ({
|
jest.mock("@google/generative-ai", () => ({
|
||||||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
||||||
getGenerativeModel: jest.fn().mockReturnValue({
|
getGenerativeModel: jest.fn().mockReturnValue({
|
||||||
generateContentStream: jest.fn(),
|
generateContentStream: jest.fn(),
|
||||||
generateContent: jest.fn().mockResolvedValue({
|
generateContent: jest.fn().mockResolvedValue({
|
||||||
response: {
|
response: {
|
||||||
text: () => 'Test response'
|
text: () => "Test response",
|
||||||
}
|
},
|
||||||
})
|
}),
|
||||||
})
|
}),
|
||||||
}))
|
})),
|
||||||
}));
|
}))
|
||||||
|
|
||||||
describe('GeminiHandler', () => {
|
describe("GeminiHandler", () => {
|
||||||
let handler: GeminiHandler;
|
let handler: GeminiHandler
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
handler = new GeminiHandler({
|
handler = new GeminiHandler({
|
||||||
apiKey: 'test-key',
|
apiKey: "test-key",
|
||||||
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
|
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
||||||
geminiApiKey: 'test-key'
|
geminiApiKey: "test-key",
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe("constructor", () => {
|
||||||
it('should initialize with provided config', () => {
|
it("should initialize with provided config", () => {
|
||||||
expect(handler['options'].geminiApiKey).toBe('test-key');
|
expect(handler["options"].geminiApiKey).toBe("test-key")
|
||||||
expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219');
|
expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should throw if API key is missing', () => {
|
it("should throw if API key is missing", () => {
|
||||||
expect(() => {
|
expect(() => {
|
||||||
new GeminiHandler({
|
new GeminiHandler({
|
||||||
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
|
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
||||||
geminiApiKey: ''
|
geminiApiKey: "",
|
||||||
});
|
})
|
||||||
}).toThrow('API key is required for Google Gemini');
|
}).toThrow("API key is required for Google Gemini")
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe("createMessage", () => {
|
||||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: 'Hello'
|
content: "Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: 'Hi there!'
|
content: "Hi there!",
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
const systemPrompt = "You are a helpful assistant"
|
||||||
|
|
||||||
it('should handle text messages correctly', async () => {
|
it("should handle text messages correctly", async () => {
|
||||||
// Mock the stream response
|
// Mock the stream response
|
||||||
const mockStream = {
|
const mockStream = {
|
||||||
stream: [
|
stream: [{ text: () => "Hello" }, { text: () => " world!" }],
|
||||||
{ text: () => 'Hello' },
|
response: {
|
||||||
{ text: () => ' world!' }
|
usageMetadata: {
|
||||||
],
|
promptTokenCount: 10,
|
||||||
response: {
|
candidatesTokenCount: 5,
|
||||||
usageMetadata: {
|
},
|
||||||
promptTokenCount: 10,
|
},
|
||||||
candidatesTokenCount: 5
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Setup the mock implementation
|
// Setup the mock implementation
|
||||||
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream);
|
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
|
||||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
generateContentStream: mockGenerateContentStream
|
generateContentStream: mockGenerateContentStream,
|
||||||
});
|
})
|
||||||
|
|
||||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
const chunks = [];
|
const chunks = []
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should have 3 chunks: 'Hello', ' world!', and usage info
|
for await (const chunk of stream) {
|
||||||
expect(chunks.length).toBe(3);
|
chunks.push(chunk)
|
||||||
expect(chunks[0]).toEqual({
|
}
|
||||||
type: 'text',
|
|
||||||
text: 'Hello'
|
|
||||||
});
|
|
||||||
expect(chunks[1]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: ' world!'
|
|
||||||
});
|
|
||||||
expect(chunks[2]).toEqual({
|
|
||||||
type: 'usage',
|
|
||||||
inputTokens: 10,
|
|
||||||
outputTokens: 5
|
|
||||||
});
|
|
||||||
|
|
||||||
// Verify the model configuration
|
// Should have 3 chunks: 'Hello', ' world!', and usage info
|
||||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
expect(chunks.length).toBe(3)
|
||||||
model: 'gemini-2.0-flash-thinking-exp-1219',
|
expect(chunks[0]).toEqual({
|
||||||
systemInstruction: systemPrompt
|
type: "text",
|
||||||
});
|
text: "Hello",
|
||||||
|
})
|
||||||
|
expect(chunks[1]).toEqual({
|
||||||
|
type: "text",
|
||||||
|
text: " world!",
|
||||||
|
})
|
||||||
|
expect(chunks[2]).toEqual({
|
||||||
|
type: "usage",
|
||||||
|
inputTokens: 10,
|
||||||
|
outputTokens: 5,
|
||||||
|
})
|
||||||
|
|
||||||
// Verify generation config
|
// Verify the model configuration
|
||||||
expect(mockGenerateContentStream).toHaveBeenCalledWith(
|
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||||
expect.objectContaining({
|
model: "gemini-2.0-flash-thinking-exp-1219",
|
||||||
generationConfig: {
|
systemInstruction: systemPrompt,
|
||||||
temperature: 0
|
})
|
||||||
}
|
|
||||||
})
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
// Verify generation config
|
||||||
const mockError = new Error('Gemini API error');
|
expect(mockGenerateContentStream).toHaveBeenCalledWith(
|
||||||
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError);
|
expect.objectContaining({
|
||||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
generationConfig: {
|
||||||
generateContentStream: mockGenerateContentStream
|
temperature: 0,
|
||||||
});
|
},
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
it("should handle API errors", async () => {
|
||||||
|
const mockError = new Error("Gemini API error")
|
||||||
|
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
|
||||||
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
|
generateContentStream: mockGenerateContentStream,
|
||||||
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||||
|
|
||||||
await expect(async () => {
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
for await (const chunk of stream) {
|
|
||||||
// Should throw before yielding any chunks
|
|
||||||
}
|
|
||||||
}).rejects.toThrow('Gemini API error');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
await expect(async () => {
|
||||||
it('should complete prompt successfully', async () => {
|
for await (const chunk of stream) {
|
||||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
// Should throw before yielding any chunks
|
||||||
response: {
|
}
|
||||||
text: () => 'Test response'
|
}).rejects.toThrow("Gemini API error")
|
||||||
}
|
})
|
||||||
});
|
})
|
||||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
||||||
generateContent: mockGenerateContent
|
|
||||||
});
|
|
||||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
describe("completePrompt", () => {
|
||||||
expect(result).toBe('Test response');
|
it("should complete prompt successfully", async () => {
|
||||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||||
model: 'gemini-2.0-flash-thinking-exp-1219'
|
response: {
|
||||||
});
|
text: () => "Test response",
|
||||||
expect(mockGenerateContent).toHaveBeenCalledWith({
|
},
|
||||||
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
|
})
|
||||||
generationConfig: {
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
temperature: 0
|
generateContent: mockGenerateContent,
|
||||||
}
|
})
|
||||||
});
|
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
const mockError = new Error('Gemini API error');
|
expect(result).toBe("Test response")
|
||||||
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
|
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
model: "gemini-2.0-flash-thinking-exp-1219",
|
||||||
generateContent: mockGenerateContent
|
})
|
||||||
});
|
expect(mockGenerateContent).toHaveBeenCalledWith({
|
||||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
|
||||||
|
generationConfig: {
|
||||||
|
temperature: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
it("should handle API errors", async () => {
|
||||||
.rejects.toThrow('Gemini completion error: Gemini API error');
|
const mockError = new Error("Gemini API error")
|
||||||
});
|
const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
|
||||||
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
})
|
||||||
|
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
"Gemini completion error: Gemini API error",
|
||||||
response: {
|
)
|
||||||
text: () => ''
|
})
|
||||||
}
|
|
||||||
});
|
|
||||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
||||||
generateContent: mockGenerateContent
|
|
||||||
});
|
|
||||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
it("should handle empty response", async () => {
|
||||||
expect(result).toBe('');
|
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||||
});
|
response: {
|
||||||
});
|
text: () => "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
})
|
||||||
|
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||||
|
|
||||||
describe('getModel', () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
it('should return correct model info', () => {
|
expect(result).toBe("")
|
||||||
const modelInfo = handler.getModel();
|
})
|
||||||
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219');
|
})
|
||||||
expect(modelInfo.info).toBeDefined();
|
|
||||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
|
||||||
expect(modelInfo.info.contextWindow).toBe(32_767);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should return default model if invalid model specified', () => {
|
describe("getModel", () => {
|
||||||
const invalidHandler = new GeminiHandler({
|
it("should return correct model info", () => {
|
||||||
apiModelId: 'invalid-model',
|
const modelInfo = handler.getModel()
|
||||||
geminiApiKey: 'test-key'
|
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
|
||||||
});
|
expect(modelInfo.info).toBeDefined()
|
||||||
const modelInfo = invalidHandler.getModel();
|
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||||
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); // Default model
|
expect(modelInfo.info.contextWindow).toBe(32_767)
|
||||||
});
|
})
|
||||||
});
|
|
||||||
});
|
it("should return default model if invalid model specified", () => {
|
||||||
|
const invalidHandler = new GeminiHandler({
|
||||||
|
apiModelId: "invalid-model",
|
||||||
|
geminiApiKey: "test-key",
|
||||||
|
})
|
||||||
|
const modelInfo = invalidHandler.getModel()
|
||||||
|
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,226 +1,238 @@
|
|||||||
import { GlamaHandler } from '../glama';
|
import { GlamaHandler } from "../glama"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import axios from 'axios';
|
import axios from "axios"
|
||||||
|
|
||||||
// Mock OpenAI client
|
// Mock OpenAI client
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
const mockWithResponse = jest.fn();
|
const mockWithResponse = jest.fn()
|
||||||
|
|
||||||
jest.mock('openai', () => {
|
jest.mock("openai", () => {
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: (...args: any[]) => {
|
create: (...args: any[]) => {
|
||||||
const stream = {
|
const stream = {
|
||||||
[Symbol.asyncIterator]: async function* () {
|
[Symbol.asyncIterator]: async function* () {
|
||||||
yield {
|
yield {
|
||||||
choices: [{
|
choices: [
|
||||||
delta: { content: 'Test response' },
|
{
|
||||||
index: 0
|
delta: { content: "Test response" },
|
||||||
}],
|
index: 0,
|
||||||
usage: null
|
},
|
||||||
};
|
],
|
||||||
yield {
|
usage: null,
|
||||||
choices: [{
|
}
|
||||||
delta: {},
|
yield {
|
||||||
index: 0
|
choices: [
|
||||||
}],
|
{
|
||||||
usage: {
|
delta: {},
|
||||||
prompt_tokens: 10,
|
index: 0,
|
||||||
completion_tokens: 5,
|
},
|
||||||
total_tokens: 15
|
],
|
||||||
}
|
usage: {
|
||||||
};
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
};
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
const result = mockCreate(...args);
|
const result = mockCreate(...args)
|
||||||
if (args[0].stream) {
|
if (args[0].stream) {
|
||||||
mockWithResponse.mockReturnValue(Promise.resolve({
|
mockWithResponse.mockReturnValue(
|
||||||
data: stream,
|
Promise.resolve({
|
||||||
response: {
|
data: stream,
|
||||||
headers: {
|
response: {
|
||||||
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
|
headers: {
|
||||||
}
|
get: (name: string) =>
|
||||||
}
|
name === "x-completion-request-id" ? "test-request-id" : null,
|
||||||
}));
|
},
|
||||||
result.withResponse = mockWithResponse;
|
},
|
||||||
}
|
}),
|
||||||
return result;
|
)
|
||||||
}
|
result.withResponse = mockWithResponse
|
||||||
}
|
}
|
||||||
}
|
return result
|
||||||
}))
|
},
|
||||||
};
|
},
|
||||||
});
|
},
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
describe('GlamaHandler', () => {
|
describe("GlamaHandler", () => {
|
||||||
let handler: GlamaHandler;
|
let handler: GlamaHandler
|
||||||
let mockOptions: ApiHandlerOptions;
|
let mockOptions: ApiHandlerOptions
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockOptions = {
|
mockOptions = {
|
||||||
apiModelId: 'anthropic/claude-3-5-sonnet',
|
apiModelId: "anthropic/claude-3-5-sonnet",
|
||||||
glamaModelId: 'anthropic/claude-3-5-sonnet',
|
glamaModelId: "anthropic/claude-3-5-sonnet",
|
||||||
glamaApiKey: 'test-api-key'
|
glamaApiKey: "test-api-key",
|
||||||
};
|
}
|
||||||
handler = new GlamaHandler(mockOptions);
|
handler = new GlamaHandler(mockOptions)
|
||||||
mockCreate.mockClear();
|
mockCreate.mockClear()
|
||||||
mockWithResponse.mockClear();
|
mockWithResponse.mockClear()
|
||||||
|
|
||||||
// Default mock implementation for non-streaming responses
|
// Default mock implementation for non-streaming responses
|
||||||
mockCreate.mockResolvedValue({
|
mockCreate.mockResolvedValue({
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
choices: [{
|
choices: [
|
||||||
message: { role: 'assistant', content: 'Test response' },
|
{
|
||||||
finish_reason: 'stop',
|
message: { role: "assistant", content: "Test response" },
|
||||||
index: 0
|
finish_reason: "stop",
|
||||||
}],
|
index: 0,
|
||||||
usage: {
|
},
|
||||||
prompt_tokens: 10,
|
],
|
||||||
completion_tokens: 5,
|
usage: {
|
||||||
total_tokens: 15
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
});
|
total_tokens: 15,
|
||||||
});
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe("constructor", () => {
|
||||||
it('should initialize with provided options', () => {
|
it("should initialize with provided options", () => {
|
||||||
expect(handler).toBeInstanceOf(GlamaHandler);
|
expect(handler).toBeInstanceOf(GlamaHandler)
|
||||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe("createMessage", () => {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
const systemPrompt = "You are a helpful assistant."
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: 'Hello!'
|
content: "Hello!",
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
it('should handle streaming responses', async () => {
|
it("should handle streaming responses", async () => {
|
||||||
// Mock axios for token usage request
|
// Mock axios for token usage request
|
||||||
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
|
const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({
|
||||||
data: {
|
data: {
|
||||||
tokenUsage: {
|
tokenUsage: {
|
||||||
promptTokens: 10,
|
promptTokens: 10,
|
||||||
completionTokens: 5,
|
completionTokens: 5,
|
||||||
cacheCreationInputTokens: 0,
|
cacheCreationInputTokens: 0,
|
||||||
cacheReadInputTokens: 0
|
cacheReadInputTokens: 0,
|
||||||
},
|
},
|
||||||
totalCostUsd: "0.00"
|
totalCostUsd: "0.00",
|
||||||
}
|
},
|
||||||
});
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
const chunks: any[] = [];
|
const chunks: any[] = []
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks.length).toBe(2); // Text chunk and usage chunk
|
expect(chunks.length).toBe(2) // Text chunk and usage chunk
|
||||||
expect(chunks[0]).toEqual({
|
expect(chunks[0]).toEqual({
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Test response'
|
text: "Test response",
|
||||||
});
|
})
|
||||||
expect(chunks[1]).toEqual({
|
expect(chunks[1]).toEqual({
|
||||||
type: 'usage',
|
type: "usage",
|
||||||
inputTokens: 10,
|
inputTokens: 10,
|
||||||
outputTokens: 5,
|
outputTokens: 5,
|
||||||
cacheWriteTokens: 0,
|
cacheWriteTokens: 0,
|
||||||
cacheReadTokens: 0,
|
cacheReadTokens: 0,
|
||||||
totalCost: 0
|
totalCost: 0,
|
||||||
});
|
})
|
||||||
|
|
||||||
mockAxios.mockRestore();
|
mockAxios.mockRestore()
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
it("should handle API errors", async () => {
|
||||||
mockCreate.mockImplementationOnce(() => {
|
mockCreate.mockImplementationOnce(() => {
|
||||||
throw new Error('API Error');
|
throw new Error("API Error")
|
||||||
});
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
const chunks = [];
|
const chunks = []
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk)
|
||||||
}
|
}
|
||||||
fail('Expected error to be thrown');
|
fail("Expected error to be thrown")
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
expect(error).toBeInstanceOf(Error);
|
expect(error).toBeInstanceOf(Error)
|
||||||
expect(error.message).toBe('API Error');
|
expect(error.message).toBe("API Error")
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
describe("completePrompt", () => {
|
||||||
it('should complete prompt successfully', async () => {
|
it("should complete prompt successfully", async () => {
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe('Test response');
|
expect(result).toBe("Test response")
|
||||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
expect(mockCreate).toHaveBeenCalledWith(
|
||||||
model: mockOptions.apiModelId,
|
expect.objectContaining({
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
model: mockOptions.apiModelId,
|
||||||
temperature: 0,
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
max_tokens: 8192
|
temperature: 0,
|
||||||
}));
|
max_tokens: 8192,
|
||||||
});
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
it("should handle API errors", async () => {
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error")
|
||||||
.rejects.toThrow('Glama completion error: API Error');
|
})
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
it("should handle empty response", async () => {
|
||||||
mockCreate.mockResolvedValueOnce({
|
mockCreate.mockResolvedValueOnce({
|
||||||
choices: [{ message: { content: '' } }]
|
choices: [{ message: { content: "" } }],
|
||||||
});
|
})
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe('');
|
expect(result).toBe("")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should not set max_tokens for non-Anthropic models', async () => {
|
it("should not set max_tokens for non-Anthropic models", async () => {
|
||||||
// Reset mock to clear any previous calls
|
// Reset mock to clear any previous calls
|
||||||
mockCreate.mockClear();
|
mockCreate.mockClear()
|
||||||
|
|
||||||
const nonAnthropicOptions = {
|
|
||||||
apiModelId: 'openai/gpt-4',
|
|
||||||
glamaModelId: 'openai/gpt-4',
|
|
||||||
glamaApiKey: 'test-key',
|
|
||||||
glamaModelInfo: {
|
|
||||||
maxTokens: 4096,
|
|
||||||
contextWindow: 8192,
|
|
||||||
supportsImages: true,
|
|
||||||
supportsPromptCache: false
|
|
||||||
}
|
|
||||||
};
|
|
||||||
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
|
|
||||||
|
|
||||||
await nonAnthropicHandler.completePrompt('Test prompt');
|
const nonAnthropicOptions = {
|
||||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
apiModelId: "openai/gpt-4",
|
||||||
model: 'openai/gpt-4',
|
glamaModelId: "openai/gpt-4",
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
glamaApiKey: "test-key",
|
||||||
temperature: 0
|
glamaModelInfo: {
|
||||||
}));
|
maxTokens: 4096,
|
||||||
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
|
contextWindow: 8192,
|
||||||
});
|
supportsImages: true,
|
||||||
});
|
supportsPromptCache: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions)
|
||||||
|
|
||||||
describe('getModel', () => {
|
await nonAnthropicHandler.completePrompt("Test prompt")
|
||||||
it('should return model info', () => {
|
expect(mockCreate).toHaveBeenCalledWith(
|
||||||
const modelInfo = handler.getModel();
|
expect.objectContaining({
|
||||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
model: "openai/gpt-4",
|
||||||
expect(modelInfo.info).toBeDefined();
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
temperature: 0,
|
||||||
expect(modelInfo.info.contextWindow).toBe(200_000);
|
}),
|
||||||
});
|
)
|
||||||
});
|
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
|
||||||
});
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("getModel", () => {
|
||||||
|
it("should return model info", () => {
|
||||||
|
const modelInfo = handler.getModel()
|
||||||
|
expect(modelInfo.id).toBe(mockOptions.apiModelId)
|
||||||
|
expect(modelInfo.info).toBeDefined()
|
||||||
|
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(200_000)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,160 +1,167 @@
|
|||||||
import { LmStudioHandler } from '../lmstudio';
|
import { LmStudioHandler } from "../lmstudio"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock OpenAI client
|
// Mock OpenAI client
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
jest.mock('openai', () => {
|
jest.mock("openai", () => {
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: mockCreate.mockImplementation(async (options) => {
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
choices: [{
|
choices: [
|
||||||
message: { role: 'assistant', content: 'Test response' },
|
{
|
||||||
finish_reason: 'stop',
|
message: { role: "assistant", content: "Test response" },
|
||||||
index: 0
|
finish_reason: "stop",
|
||||||
}],
|
index: 0,
|
||||||
usage: {
|
},
|
||||||
prompt_tokens: 10,
|
],
|
||||||
completion_tokens: 5,
|
usage: {
|
||||||
total_tokens: 15
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
};
|
total_tokens: 15,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
return {
|
}
|
||||||
[Symbol.asyncIterator]: async function* () {
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Test response' },
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: null
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: {},
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 10,
|
|
||||||
completion_tokens: 5,
|
|
||||||
total_tokens: 15
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('LmStudioHandler', () => {
|
return {
|
||||||
let handler: LmStudioHandler;
|
[Symbol.asyncIterator]: async function* () {
|
||||||
let mockOptions: ApiHandlerOptions;
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: "Test response" },
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: null,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
describe("LmStudioHandler", () => {
|
||||||
mockOptions = {
|
let handler: LmStudioHandler
|
||||||
apiModelId: 'local-model',
|
let mockOptions: ApiHandlerOptions
|
||||||
lmStudioModelId: 'local-model',
|
|
||||||
lmStudioBaseUrl: 'http://localhost:1234/v1'
|
|
||||||
};
|
|
||||||
handler = new LmStudioHandler(mockOptions);
|
|
||||||
mockCreate.mockClear();
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('constructor', () => {
|
beforeEach(() => {
|
||||||
it('should initialize with provided options', () => {
|
mockOptions = {
|
||||||
expect(handler).toBeInstanceOf(LmStudioHandler);
|
apiModelId: "local-model",
|
||||||
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
|
lmStudioModelId: "local-model",
|
||||||
});
|
lmStudioBaseUrl: "http://localhost:1234/v1",
|
||||||
|
}
|
||||||
|
handler = new LmStudioHandler(mockOptions)
|
||||||
|
mockCreate.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
it('should use default base URL if not provided', () => {
|
describe("constructor", () => {
|
||||||
const handlerWithoutUrl = new LmStudioHandler({
|
it("should initialize with provided options", () => {
|
||||||
apiModelId: 'local-model',
|
expect(handler).toBeInstanceOf(LmStudioHandler)
|
||||||
lmStudioModelId: 'local-model'
|
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId)
|
||||||
});
|
})
|
||||||
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('createMessage', () => {
|
it("should use default base URL if not provided", () => {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
const handlerWithoutUrl = new LmStudioHandler({
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
apiModelId: "local-model",
|
||||||
{
|
lmStudioModelId: "local-model",
|
||||||
role: 'user',
|
})
|
||||||
content: 'Hello!'
|
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler)
|
||||||
}
|
})
|
||||||
];
|
})
|
||||||
|
|
||||||
it('should handle streaming responses', async () => {
|
describe("createMessage", () => {
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const systemPrompt = "You are a helpful assistant."
|
||||||
const chunks: any[] = [];
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
for await (const chunk of stream) {
|
{
|
||||||
chunks.push(chunk);
|
role: "user",
|
||||||
}
|
content: "Hello!",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
it("should handle streaming responses", async () => {
|
||||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
expect(textChunks).toHaveLength(1);
|
const chunks: any[] = []
|
||||||
expect(textChunks[0].text).toBe('Test response');
|
for await (const chunk of stream) {
|
||||||
});
|
chunks.push(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||||
|
expect(textChunks).toHaveLength(1)
|
||||||
|
expect(textChunks[0].text).toBe("Test response")
|
||||||
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
it("should handle API errors", async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
|
|
||||||
await expect(async () => {
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
for await (const chunk of stream) {
|
|
||||||
// Should not reach here
|
|
||||||
}
|
|
||||||
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
await expect(async () => {
|
||||||
it('should complete prompt successfully', async () => {
|
for await (const chunk of stream) {
|
||||||
const result = await handler.completePrompt('Test prompt');
|
// Should not reach here
|
||||||
expect(result).toBe('Test response');
|
}
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
}).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong")
|
||||||
model: mockOptions.lmStudioModelId,
|
})
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
})
|
||||||
temperature: 0,
|
|
||||||
stream: false
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
describe("completePrompt", () => {
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
it("should complete prompt successfully", async () => {
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
const result = await handler.completePrompt("Test prompt")
|
||||||
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
expect(result).toBe("Test response")
|
||||||
});
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.lmStudioModelId,
|
||||||
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
it("should handle API errors", async () => {
|
||||||
mockCreate.mockResolvedValueOnce({
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
choices: [{ message: { content: '' } }]
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||||
});
|
"Please check the LM Studio developer logs to debug what went wrong",
|
||||||
const result = await handler.completePrompt('Test prompt');
|
)
|
||||||
expect(result).toBe('');
|
})
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getModel', () => {
|
it("should handle empty response", async () => {
|
||||||
it('should return model info', () => {
|
mockCreate.mockResolvedValueOnce({
|
||||||
const modelInfo = handler.getModel();
|
choices: [{ message: { content: "" } }],
|
||||||
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
|
})
|
||||||
expect(modelInfo.info).toBeDefined();
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
expect(result).toBe("")
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
})
|
||||||
});
|
})
|
||||||
});
|
|
||||||
});
|
describe("getModel", () => {
|
||||||
|
it("should return model info", () => {
|
||||||
|
const modelInfo = handler.getModel()
|
||||||
|
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
|
||||||
|
expect(modelInfo.info).toBeDefined()
|
||||||
|
expect(modelInfo.info.maxTokens).toBe(-1)
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,160 +1,165 @@
|
|||||||
import { OllamaHandler } from '../ollama';
|
import { OllamaHandler } from "../ollama"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock OpenAI client
|
// Mock OpenAI client
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
jest.mock('openai', () => {
|
jest.mock("openai", () => {
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: mockCreate.mockImplementation(async (options) => {
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
choices: [{
|
choices: [
|
||||||
message: { role: 'assistant', content: 'Test response' },
|
{
|
||||||
finish_reason: 'stop',
|
message: { role: "assistant", content: "Test response" },
|
||||||
index: 0
|
finish_reason: "stop",
|
||||||
}],
|
index: 0,
|
||||||
usage: {
|
},
|
||||||
prompt_tokens: 10,
|
],
|
||||||
completion_tokens: 5,
|
usage: {
|
||||||
total_tokens: 15
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
};
|
total_tokens: 15,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
return {
|
}
|
||||||
[Symbol.asyncIterator]: async function* () {
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Test response' },
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: null
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: {},
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 10,
|
|
||||||
completion_tokens: 5,
|
|
||||||
total_tokens: 15
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('OllamaHandler', () => {
|
return {
|
||||||
let handler: OllamaHandler;
|
[Symbol.asyncIterator]: async function* () {
|
||||||
let mockOptions: ApiHandlerOptions;
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: "Test response" },
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: null,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
describe("OllamaHandler", () => {
|
||||||
mockOptions = {
|
let handler: OllamaHandler
|
||||||
apiModelId: 'llama2',
|
let mockOptions: ApiHandlerOptions
|
||||||
ollamaModelId: 'llama2',
|
|
||||||
ollamaBaseUrl: 'http://localhost:11434/v1'
|
|
||||||
};
|
|
||||||
handler = new OllamaHandler(mockOptions);
|
|
||||||
mockCreate.mockClear();
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('constructor', () => {
|
beforeEach(() => {
|
||||||
it('should initialize with provided options', () => {
|
mockOptions = {
|
||||||
expect(handler).toBeInstanceOf(OllamaHandler);
|
apiModelId: "llama2",
|
||||||
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
|
ollamaModelId: "llama2",
|
||||||
});
|
ollamaBaseUrl: "http://localhost:11434/v1",
|
||||||
|
}
|
||||||
|
handler = new OllamaHandler(mockOptions)
|
||||||
|
mockCreate.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
it('should use default base URL if not provided', () => {
|
describe("constructor", () => {
|
||||||
const handlerWithoutUrl = new OllamaHandler({
|
it("should initialize with provided options", () => {
|
||||||
apiModelId: 'llama2',
|
expect(handler).toBeInstanceOf(OllamaHandler)
|
||||||
ollamaModelId: 'llama2'
|
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId)
|
||||||
});
|
})
|
||||||
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('createMessage', () => {
|
it("should use default base URL if not provided", () => {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
const handlerWithoutUrl = new OllamaHandler({
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
apiModelId: "llama2",
|
||||||
{
|
ollamaModelId: "llama2",
|
||||||
role: 'user',
|
})
|
||||||
content: 'Hello!'
|
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler)
|
||||||
}
|
})
|
||||||
];
|
})
|
||||||
|
|
||||||
it('should handle streaming responses', async () => {
|
describe("createMessage", () => {
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const systemPrompt = "You are a helpful assistant."
|
||||||
const chunks: any[] = [];
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
for await (const chunk of stream) {
|
{
|
||||||
chunks.push(chunk);
|
role: "user",
|
||||||
}
|
content: "Hello!",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
it("should handle streaming responses", async () => {
|
||||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
expect(textChunks).toHaveLength(1);
|
const chunks: any[] = []
|
||||||
expect(textChunks[0].text).toBe('Test response');
|
for await (const chunk of stream) {
|
||||||
});
|
chunks.push(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||||
|
expect(textChunks).toHaveLength(1)
|
||||||
|
expect(textChunks[0].text).toBe("Test response")
|
||||||
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
it("should handle API errors", async () => {
|
||||||
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
|
|
||||||
await expect(async () => {
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
for await (const chunk of stream) {
|
|
||||||
// Should not reach here
|
|
||||||
}
|
|
||||||
}).rejects.toThrow('API Error');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
await expect(async () => {
|
||||||
it('should complete prompt successfully', async () => {
|
for await (const chunk of stream) {
|
||||||
const result = await handler.completePrompt('Test prompt');
|
// Should not reach here
|
||||||
expect(result).toBe('Test response');
|
}
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
}).rejects.toThrow("API Error")
|
||||||
model: mockOptions.ollamaModelId,
|
})
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
})
|
||||||
temperature: 0,
|
|
||||||
stream: false
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
describe("completePrompt", () => {
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
it("should complete prompt successfully", async () => {
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
const result = await handler.completePrompt("Test prompt")
|
||||||
.rejects.toThrow('Ollama completion error: API Error');
|
expect(result).toBe("Test response")
|
||||||
});
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.ollamaModelId,
|
||||||
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
it("should handle API errors", async () => {
|
||||||
mockCreate.mockResolvedValueOnce({
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
choices: [{ message: { content: '' } }]
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error")
|
||||||
});
|
})
|
||||||
const result = await handler.completePrompt('Test prompt');
|
|
||||||
expect(result).toBe('');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getModel', () => {
|
it("should handle empty response", async () => {
|
||||||
it('should return model info', () => {
|
mockCreate.mockResolvedValueOnce({
|
||||||
const modelInfo = handler.getModel();
|
choices: [{ message: { content: "" } }],
|
||||||
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
|
})
|
||||||
expect(modelInfo.info).toBeDefined();
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
expect(result).toBe("")
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
})
|
||||||
});
|
})
|
||||||
});
|
|
||||||
});
|
describe("getModel", () => {
|
||||||
|
it("should return model info", () => {
|
||||||
|
const modelInfo = handler.getModel()
|
||||||
|
expect(modelInfo.id).toBe(mockOptions.ollamaModelId)
|
||||||
|
expect(modelInfo.info).toBeDefined()
|
||||||
|
expect(modelInfo.info.maxTokens).toBe(-1)
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,319 +1,326 @@
|
|||||||
import { OpenAiNativeHandler } from '../openai-native';
|
import { OpenAiNativeHandler } from "../openai-native"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock OpenAI client
|
// Mock OpenAI client
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
jest.mock('openai', () => {
|
jest.mock("openai", () => {
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: mockCreate.mockImplementation(async (options) => {
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
choices: [{
|
choices: [
|
||||||
message: { role: 'assistant', content: 'Test response' },
|
{
|
||||||
finish_reason: 'stop',
|
message: { role: "assistant", content: "Test response" },
|
||||||
index: 0
|
finish_reason: "stop",
|
||||||
}],
|
index: 0,
|
||||||
usage: {
|
},
|
||||||
prompt_tokens: 10,
|
],
|
||||||
completion_tokens: 5,
|
usage: {
|
||||||
total_tokens: 15
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
};
|
total_tokens: 15,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
return {
|
}
|
||||||
[Symbol.asyncIterator]: async function* () {
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Test response' },
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: null
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: {},
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 10,
|
|
||||||
completion_tokens: 5,
|
|
||||||
total_tokens: 15
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('OpenAiNativeHandler', () => {
|
return {
|
||||||
let handler: OpenAiNativeHandler;
|
[Symbol.asyncIterator]: async function* () {
|
||||||
let mockOptions: ApiHandlerOptions;
|
yield {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
choices: [
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
{
|
||||||
{
|
delta: { content: "Test response" },
|
||||||
role: 'user',
|
index: 0,
|
||||||
content: 'Hello!'
|
},
|
||||||
}
|
],
|
||||||
];
|
usage: null,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
describe("OpenAiNativeHandler", () => {
|
||||||
mockOptions = {
|
let handler: OpenAiNativeHandler
|
||||||
apiModelId: 'gpt-4o',
|
let mockOptions: ApiHandlerOptions
|
||||||
openAiNativeApiKey: 'test-api-key'
|
const systemPrompt = "You are a helpful assistant."
|
||||||
};
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
handler = new OpenAiNativeHandler(mockOptions);
|
{
|
||||||
mockCreate.mockClear();
|
role: "user",
|
||||||
});
|
content: "Hello!",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
describe('constructor', () => {
|
beforeEach(() => {
|
||||||
it('should initialize with provided options', () => {
|
mockOptions = {
|
||||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
|
apiModelId: "gpt-4o",
|
||||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
openAiNativeApiKey: "test-api-key",
|
||||||
});
|
}
|
||||||
|
handler = new OpenAiNativeHandler(mockOptions)
|
||||||
|
mockCreate.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
it('should initialize with empty API key', () => {
|
describe("constructor", () => {
|
||||||
const handlerWithoutKey = new OpenAiNativeHandler({
|
it("should initialize with provided options", () => {
|
||||||
apiModelId: 'gpt-4o',
|
expect(handler).toBeInstanceOf(OpenAiNativeHandler)
|
||||||
openAiNativeApiKey: ''
|
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||||
});
|
})
|
||||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('createMessage', () => {
|
it("should initialize with empty API key", () => {
|
||||||
it('should handle streaming responses', async () => {
|
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
apiModelId: "gpt-4o",
|
||||||
const chunks: any[] = [];
|
openAiNativeApiKey: "",
|
||||||
for await (const chunk of stream) {
|
})
|
||||||
chunks.push(chunk);
|
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler)
|
||||||
}
|
})
|
||||||
|
})
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
describe("createMessage", () => {
|
||||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
it("should handle streaming responses", async () => {
|
||||||
expect(textChunks).toHaveLength(1);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
expect(textChunks[0].text).toBe('Test response');
|
const chunks: any[] = []
|
||||||
});
|
for await (const chunk of stream) {
|
||||||
|
chunks.push(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
expect(textChunks).toHaveLength(1)
|
||||||
await expect(async () => {
|
expect(textChunks[0].text).toBe("Test response")
|
||||||
for await (const chunk of stream) {
|
})
|
||||||
// Should not reach here
|
|
||||||
}
|
|
||||||
}).rejects.toThrow('API Error');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle missing content in response for o1 model', async () => {
|
it("should handle API errors", async () => {
|
||||||
// Use o1 model which supports developer role
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
handler = new OpenAiNativeHandler({
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
...mockOptions,
|
await expect(async () => {
|
||||||
apiModelId: 'o1'
|
for await (const chunk of stream) {
|
||||||
});
|
// Should not reach here
|
||||||
|
}
|
||||||
|
}).rejects.toThrow("API Error")
|
||||||
|
})
|
||||||
|
|
||||||
mockCreate.mockResolvedValueOnce({
|
it("should handle missing content in response for o1 model", async () => {
|
||||||
choices: [{ message: { content: null } }],
|
// Use o1 model which supports developer role
|
||||||
usage: {
|
handler = new OpenAiNativeHandler({
|
||||||
prompt_tokens: 0,
|
...mockOptions,
|
||||||
completion_tokens: 0,
|
apiModelId: "o1",
|
||||||
total_tokens: 0
|
})
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages);
|
mockCreate.mockResolvedValueOnce({
|
||||||
const results = [];
|
choices: [{ message: { content: null } }],
|
||||||
for await (const result of generator) {
|
usage: {
|
||||||
results.push(result);
|
prompt_tokens: 0,
|
||||||
}
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
expect(results).toEqual([
|
const generator = handler.createMessage(systemPrompt, messages)
|
||||||
{ type: 'text', text: '' },
|
const results = []
|
||||||
{ type: 'usage', inputTokens: 0, outputTokens: 0 }
|
for await (const result of generator) {
|
||||||
]);
|
results.push(result)
|
||||||
|
}
|
||||||
|
|
||||||
// Verify developer role is used for system prompt with o1 model
|
expect(results).toEqual([
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
{ type: "text", text: "" },
|
||||||
model: 'o1',
|
{ type: "usage", inputTokens: 0, outputTokens: 0 },
|
||||||
messages: [
|
])
|
||||||
{ role: 'developer', content: systemPrompt },
|
|
||||||
{ role: 'user', content: 'Hello!' }
|
|
||||||
]
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('streaming models', () => {
|
// Verify developer role is used for system prompt with o1 model
|
||||||
beforeEach(() => {
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
handler = new OpenAiNativeHandler({
|
model: "o1",
|
||||||
...mockOptions,
|
messages: [
|
||||||
apiModelId: 'gpt-4o',
|
{ role: "developer", content: systemPrompt },
|
||||||
});
|
{ role: "user", content: "Hello!" },
|
||||||
});
|
],
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle streaming response', async () => {
|
describe("streaming models", () => {
|
||||||
const mockStream = [
|
beforeEach(() => {
|
||||||
{ choices: [{ delta: { content: 'Hello' } }], usage: null },
|
handler = new OpenAiNativeHandler({
|
||||||
{ choices: [{ delta: { content: ' there' } }], usage: null },
|
...mockOptions,
|
||||||
{ choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
apiModelId: "gpt-4o",
|
||||||
];
|
})
|
||||||
|
})
|
||||||
|
|
||||||
mockCreate.mockResolvedValueOnce(
|
it("should handle streaming response", async () => {
|
||||||
(async function* () {
|
const mockStream = [
|
||||||
for (const chunk of mockStream) {
|
{ choices: [{ delta: { content: "Hello" } }], usage: null },
|
||||||
yield chunk;
|
{ choices: [{ delta: { content: " there" } }], usage: null },
|
||||||
}
|
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||||
})()
|
]
|
||||||
);
|
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages);
|
mockCreate.mockResolvedValueOnce(
|
||||||
const results = [];
|
(async function* () {
|
||||||
for await (const result of generator) {
|
for (const chunk of mockStream) {
|
||||||
results.push(result);
|
yield chunk
|
||||||
}
|
}
|
||||||
|
})(),
|
||||||
|
)
|
||||||
|
|
||||||
expect(results).toEqual([
|
const generator = handler.createMessage(systemPrompt, messages)
|
||||||
{ type: 'text', text: 'Hello' },
|
const results = []
|
||||||
{ type: 'text', text: ' there' },
|
for await (const result of generator) {
|
||||||
{ type: 'text', text: '!' },
|
results.push(result)
|
||||||
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
|
}
|
||||||
]);
|
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
expect(results).toEqual([
|
||||||
model: 'gpt-4o',
|
{ type: "text", text: "Hello" },
|
||||||
temperature: 0,
|
{ type: "text", text: " there" },
|
||||||
messages: [
|
{ type: "text", text: "!" },
|
||||||
{ role: 'system', content: systemPrompt },
|
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||||
{ role: 'user', content: 'Hello!' },
|
])
|
||||||
],
|
|
||||||
stream: true,
|
|
||||||
stream_options: { include_usage: true },
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle empty delta content', async () => {
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
const mockStream = [
|
model: "gpt-4o",
|
||||||
{ choices: [{ delta: {} }], usage: null },
|
temperature: 0,
|
||||||
{ choices: [{ delta: { content: null } }], usage: null },
|
messages: [
|
||||||
{ choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
{ role: "system", content: systemPrompt },
|
||||||
];
|
{ role: "user", content: "Hello!" },
|
||||||
|
],
|
||||||
|
stream: true,
|
||||||
|
stream_options: { include_usage: true },
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
mockCreate.mockResolvedValueOnce(
|
it("should handle empty delta content", async () => {
|
||||||
(async function* () {
|
const mockStream = [
|
||||||
for (const chunk of mockStream) {
|
{ choices: [{ delta: {} }], usage: null },
|
||||||
yield chunk;
|
{ choices: [{ delta: { content: null } }], usage: null },
|
||||||
}
|
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||||
})()
|
]
|
||||||
);
|
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages);
|
mockCreate.mockResolvedValueOnce(
|
||||||
const results = [];
|
(async function* () {
|
||||||
for await (const result of generator) {
|
for (const chunk of mockStream) {
|
||||||
results.push(result);
|
yield chunk
|
||||||
}
|
}
|
||||||
|
})(),
|
||||||
|
)
|
||||||
|
|
||||||
expect(results).toEqual([
|
const generator = handler.createMessage(systemPrompt, messages)
|
||||||
{ type: 'text', text: 'Hello' },
|
const results = []
|
||||||
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
|
for await (const result of generator) {
|
||||||
]);
|
results.push(result)
|
||||||
});
|
}
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
expect(results).toEqual([
|
||||||
it('should complete prompt successfully with gpt-4o model', async () => {
|
{ type: "text", text: "Hello" },
|
||||||
const result = await handler.completePrompt('Test prompt');
|
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||||
expect(result).toBe('Test response');
|
])
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
})
|
||||||
model: 'gpt-4o',
|
})
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
|
||||||
temperature: 0
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should complete prompt successfully with o1 model', async () => {
|
describe("completePrompt", () => {
|
||||||
handler = new OpenAiNativeHandler({
|
it("should complete prompt successfully with gpt-4o model", async () => {
|
||||||
apiModelId: 'o1',
|
const result = await handler.completePrompt("Test prompt")
|
||||||
openAiNativeApiKey: 'test-api-key'
|
expect(result).toBe("Test response")
|
||||||
});
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: "gpt-4o",
|
||||||
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
temperature: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
it("should complete prompt successfully with o1 model", async () => {
|
||||||
expect(result).toBe('Test response');
|
handler = new OpenAiNativeHandler({
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
apiModelId: "o1",
|
||||||
model: 'o1',
|
openAiNativeApiKey: "test-api-key",
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
})
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should complete prompt successfully with o1-preview model', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
handler = new OpenAiNativeHandler({
|
expect(result).toBe("Test response")
|
||||||
apiModelId: 'o1-preview',
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
openAiNativeApiKey: 'test-api-key'
|
model: "o1",
|
||||||
});
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
it("should complete prompt successfully with o1-preview model", async () => {
|
||||||
expect(result).toBe('Test response');
|
handler = new OpenAiNativeHandler({
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
apiModelId: "o1-preview",
|
||||||
model: 'o1-preview',
|
openAiNativeApiKey: "test-api-key",
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
})
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should complete prompt successfully with o1-mini model', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
handler = new OpenAiNativeHandler({
|
expect(result).toBe("Test response")
|
||||||
apiModelId: 'o1-mini',
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
openAiNativeApiKey: 'test-api-key'
|
model: "o1-preview",
|
||||||
});
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
it("should complete prompt successfully with o1-mini model", async () => {
|
||||||
expect(result).toBe('Test response');
|
handler = new OpenAiNativeHandler({
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
apiModelId: "o1-mini",
|
||||||
model: 'o1-mini',
|
openAiNativeApiKey: "test-api-key",
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
})
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
expect(result).toBe("Test response")
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
.rejects.toThrow('OpenAI Native completion error: API Error');
|
model: "o1-mini",
|
||||||
});
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
it("should handle API errors", async () => {
|
||||||
mockCreate.mockResolvedValueOnce({
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
choices: [{ message: { content: '' } }]
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||||
});
|
"OpenAI Native completion error: API Error",
|
||||||
const result = await handler.completePrompt('Test prompt');
|
)
|
||||||
expect(result).toBe('');
|
})
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getModel', () => {
|
it("should handle empty response", async () => {
|
||||||
it('should return model info', () => {
|
mockCreate.mockResolvedValueOnce({
|
||||||
const modelInfo = handler.getModel();
|
choices: [{ message: { content: "" } }],
|
||||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
})
|
||||||
expect(modelInfo.info).toBeDefined();
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(modelInfo.info.maxTokens).toBe(4096);
|
expect(result).toBe("")
|
||||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle undefined model ID', () => {
|
describe("getModel", () => {
|
||||||
const handlerWithoutModel = new OpenAiNativeHandler({
|
it("should return model info", () => {
|
||||||
openAiNativeApiKey: 'test-api-key'
|
const modelInfo = handler.getModel()
|
||||||
});
|
expect(modelInfo.id).toBe(mockOptions.apiModelId)
|
||||||
const modelInfo = handlerWithoutModel.getModel();
|
expect(modelInfo.info).toBeDefined()
|
||||||
expect(modelInfo.id).toBe('gpt-4o'); // Default model
|
expect(modelInfo.info.maxTokens).toBe(4096)
|
||||||
expect(modelInfo.info).toBeDefined();
|
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||||
});
|
})
|
||||||
});
|
|
||||||
});
|
it("should handle undefined model ID", () => {
|
||||||
|
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||||
|
openAiNativeApiKey: "test-api-key",
|
||||||
|
})
|
||||||
|
const modelInfo = handlerWithoutModel.getModel()
|
||||||
|
expect(modelInfo.id).toBe("gpt-4o") // Default model
|
||||||
|
expect(modelInfo.info).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,224 +1,233 @@
|
|||||||
import { OpenAiHandler } from '../openai';
|
import { OpenAiHandler } from "../openai"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import { ApiStream } from '../../transform/stream';
|
import { ApiStream } from "../../transform/stream"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock OpenAI client
|
// Mock OpenAI client
|
||||||
const mockCreate = jest.fn();
|
const mockCreate = jest.fn()
|
||||||
jest.mock('openai', () => {
|
jest.mock("openai", () => {
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: jest.fn().mockImplementation(() => ({
|
default: jest.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: mockCreate.mockImplementation(async (options) => {
|
create: mockCreate.mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
choices: [{
|
choices: [
|
||||||
message: { role: 'assistant', content: 'Test response', refusal: null },
|
{
|
||||||
finish_reason: 'stop',
|
message: { role: "assistant", content: "Test response", refusal: null },
|
||||||
index: 0
|
finish_reason: "stop",
|
||||||
}],
|
index: 0,
|
||||||
usage: {
|
},
|
||||||
prompt_tokens: 10,
|
],
|
||||||
completion_tokens: 5,
|
usage: {
|
||||||
total_tokens: 15
|
prompt_tokens: 10,
|
||||||
}
|
completion_tokens: 5,
|
||||||
};
|
total_tokens: 15,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
return {
|
}
|
||||||
[Symbol.asyncIterator]: async function* () {
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: { content: 'Test response' },
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: null
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
choices: [{
|
|
||||||
delta: {},
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 10,
|
|
||||||
completion_tokens: 5,
|
|
||||||
total_tokens: 15
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('OpenAiHandler', () => {
|
return {
|
||||||
let handler: OpenAiHandler;
|
[Symbol.asyncIterator]: async function* () {
|
||||||
let mockOptions: ApiHandlerOptions;
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: "Test response" },
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: null,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
describe("OpenAiHandler", () => {
|
||||||
mockOptions = {
|
let handler: OpenAiHandler
|
||||||
openAiApiKey: 'test-api-key',
|
let mockOptions: ApiHandlerOptions
|
||||||
openAiModelId: 'gpt-4',
|
|
||||||
openAiBaseUrl: 'https://api.openai.com/v1'
|
|
||||||
};
|
|
||||||
handler = new OpenAiHandler(mockOptions);
|
|
||||||
mockCreate.mockClear();
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('constructor', () => {
|
beforeEach(() => {
|
||||||
it('should initialize with provided options', () => {
|
mockOptions = {
|
||||||
expect(handler).toBeInstanceOf(OpenAiHandler);
|
openAiApiKey: "test-api-key",
|
||||||
expect(handler.getModel().id).toBe(mockOptions.openAiModelId);
|
openAiModelId: "gpt-4",
|
||||||
});
|
openAiBaseUrl: "https://api.openai.com/v1",
|
||||||
|
}
|
||||||
|
handler = new OpenAiHandler(mockOptions)
|
||||||
|
mockCreate.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
it('should use custom base URL if provided', () => {
|
describe("constructor", () => {
|
||||||
const customBaseUrl = 'https://custom.openai.com/v1';
|
it("should initialize with provided options", () => {
|
||||||
const handlerWithCustomUrl = new OpenAiHandler({
|
expect(handler).toBeInstanceOf(OpenAiHandler)
|
||||||
...mockOptions,
|
expect(handler.getModel().id).toBe(mockOptions.openAiModelId)
|
||||||
openAiBaseUrl: customBaseUrl
|
})
|
||||||
});
|
|
||||||
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('createMessage', () => {
|
it("should use custom base URL if provided", () => {
|
||||||
const systemPrompt = 'You are a helpful assistant.';
|
const customBaseUrl = "https://custom.openai.com/v1"
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const handlerWithCustomUrl = new OpenAiHandler({
|
||||||
{
|
...mockOptions,
|
||||||
role: 'user',
|
openAiBaseUrl: customBaseUrl,
|
||||||
content: [{
|
})
|
||||||
type: 'text' as const,
|
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
|
||||||
text: 'Hello!'
|
})
|
||||||
}]
|
})
|
||||||
}
|
|
||||||
];
|
|
||||||
|
|
||||||
it('should handle non-streaming mode', async () => {
|
describe("createMessage", () => {
|
||||||
const handler = new OpenAiHandler({
|
const systemPrompt = "You are a helpful assistant."
|
||||||
...mockOptions,
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
openAiStreamingEnabled: false
|
{
|
||||||
});
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text" as const,
|
||||||
|
text: "Hello!",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
it("should handle non-streaming mode", async () => {
|
||||||
const chunks: any[] = [];
|
const handler = new OpenAiHandler({
|
||||||
for await (const chunk of stream) {
|
...mockOptions,
|
||||||
chunks.push(chunk);
|
openAiStreamingEnabled: false,
|
||||||
}
|
})
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
const textChunk = chunks.find(chunk => chunk.type === 'text');
|
const chunks: any[] = []
|
||||||
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
|
for await (const chunk of stream) {
|
||||||
|
chunks.push(chunk)
|
||||||
expect(textChunk).toBeDefined();
|
}
|
||||||
expect(textChunk?.text).toBe('Test response');
|
|
||||||
expect(usageChunk).toBeDefined();
|
|
||||||
expect(usageChunk?.inputTokens).toBe(10);
|
|
||||||
expect(usageChunk?.outputTokens).toBe(5);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle streaming responses', async () => {
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const textChunk = chunks.find((chunk) => chunk.type === "text")
|
||||||
const chunks: any[] = [];
|
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(chunks.length).toBeGreaterThan(0);
|
expect(textChunk).toBeDefined()
|
||||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
expect(textChunk?.text).toBe("Test response")
|
||||||
expect(textChunks).toHaveLength(1);
|
expect(usageChunk).toBeDefined()
|
||||||
expect(textChunks[0].text).toBe('Test response');
|
expect(usageChunk?.inputTokens).toBe(10)
|
||||||
});
|
expect(usageChunk?.outputTokens).toBe(5)
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('error handling', () => {
|
it("should handle streaming responses", async () => {
|
||||||
const testMessages: Anthropic.Messages.MessageParam[] = [
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
{
|
const chunks: any[] = []
|
||||||
role: 'user',
|
for await (const chunk of stream) {
|
||||||
content: [{
|
chunks.push(chunk)
|
||||||
type: 'text' as const,
|
}
|
||||||
text: 'Hello'
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
];
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||||
|
expect(textChunks).toHaveLength(1)
|
||||||
|
expect(textChunks[0].text).toBe("Test response")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage('system prompt', testMessages);
|
describe("error handling", () => {
|
||||||
|
const testMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text" as const,
|
||||||
|
text: "Hello",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
await expect(async () => {
|
it("should handle API errors", async () => {
|
||||||
for await (const chunk of stream) {
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
// Should not reach here
|
|
||||||
}
|
|
||||||
}).rejects.toThrow('API Error');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle rate limiting', async () => {
|
const stream = handler.createMessage("system prompt", testMessages)
|
||||||
const rateLimitError = new Error('Rate limit exceeded');
|
|
||||||
rateLimitError.name = 'Error';
|
|
||||||
(rateLimitError as any).status = 429;
|
|
||||||
mockCreate.mockRejectedValueOnce(rateLimitError);
|
|
||||||
|
|
||||||
const stream = handler.createMessage('system prompt', testMessages);
|
await expect(async () => {
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
// Should not reach here
|
||||||
|
}
|
||||||
|
}).rejects.toThrow("API Error")
|
||||||
|
})
|
||||||
|
|
||||||
await expect(async () => {
|
it("should handle rate limiting", async () => {
|
||||||
for await (const chunk of stream) {
|
const rateLimitError = new Error("Rate limit exceeded")
|
||||||
// Should not reach here
|
rateLimitError.name = "Error"
|
||||||
}
|
;(rateLimitError as any).status = 429
|
||||||
}).rejects.toThrow('Rate limit exceeded');
|
mockCreate.mockRejectedValueOnce(rateLimitError)
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
const stream = handler.createMessage("system prompt", testMessages)
|
||||||
it('should complete prompt successfully', async () => {
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
|
||||||
expect(result).toBe('Test response');
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
|
||||||
model: mockOptions.openAiModelId,
|
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
|
||||||
temperature: 0
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
await expect(async () => {
|
||||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
for await (const chunk of stream) {
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
// Should not reach here
|
||||||
.rejects.toThrow('OpenAI completion error: API Error');
|
}
|
||||||
});
|
}).rejects.toThrow("Rate limit exceeded")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
describe("completePrompt", () => {
|
||||||
mockCreate.mockImplementationOnce(() => ({
|
it("should complete prompt successfully", async () => {
|
||||||
choices: [{ message: { content: '' } }]
|
const result = await handler.completePrompt("Test prompt")
|
||||||
}));
|
expect(result).toBe("Test response")
|
||||||
const result = await handler.completePrompt('Test prompt');
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
expect(result).toBe('');
|
model: mockOptions.openAiModelId,
|
||||||
});
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
});
|
temperature: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('getModel', () => {
|
it("should handle API errors", async () => {
|
||||||
it('should return model info with sane defaults', () => {
|
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||||
const model = handler.getModel();
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error")
|
||||||
expect(model.id).toBe(mockOptions.openAiModelId);
|
})
|
||||||
expect(model.info).toBeDefined();
|
|
||||||
expect(model.info.contextWindow).toBe(128_000);
|
|
||||||
expect(model.info.supportsImages).toBe(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle undefined model ID', () => {
|
it("should handle empty response", async () => {
|
||||||
const handlerWithoutModel = new OpenAiHandler({
|
mockCreate.mockImplementationOnce(() => ({
|
||||||
...mockOptions,
|
choices: [{ message: { content: "" } }],
|
||||||
openAiModelId: undefined
|
}))
|
||||||
});
|
const result = await handler.completePrompt("Test prompt")
|
||||||
const model = handlerWithoutModel.getModel();
|
expect(result).toBe("")
|
||||||
expect(model.id).toBe('');
|
})
|
||||||
expect(model.info).toBeDefined();
|
})
|
||||||
});
|
|
||||||
});
|
describe("getModel", () => {
|
||||||
});
|
it("should return model info with sane defaults", () => {
|
||||||
|
const model = handler.getModel()
|
||||||
|
expect(model.id).toBe(mockOptions.openAiModelId)
|
||||||
|
expect(model.info).toBeDefined()
|
||||||
|
expect(model.info.contextWindow).toBe(128_000)
|
||||||
|
expect(model.info.supportsImages).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle undefined model ID", () => {
|
||||||
|
const handlerWithoutModel = new OpenAiHandler({
|
||||||
|
...mockOptions,
|
||||||
|
openAiModelId: undefined,
|
||||||
|
})
|
||||||
|
const model = handlerWithoutModel.getModel()
|
||||||
|
expect(model.id).toBe("")
|
||||||
|
expect(model.info).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,283 +1,297 @@
|
|||||||
import { OpenRouterHandler } from '../openrouter'
|
import { OpenRouterHandler } from "../openrouter"
|
||||||
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
|
import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
|
||||||
import OpenAI from 'openai'
|
import OpenAI from "openai"
|
||||||
import axios from 'axios'
|
import axios from "axios"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk'
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
jest.mock('openai')
|
jest.mock("openai")
|
||||||
jest.mock('axios')
|
jest.mock("axios")
|
||||||
jest.mock('delay', () => jest.fn(() => Promise.resolve()))
|
jest.mock("delay", () => jest.fn(() => Promise.resolve()))
|
||||||
|
|
||||||
describe('OpenRouterHandler', () => {
|
describe("OpenRouterHandler", () => {
|
||||||
const mockOptions: ApiHandlerOptions = {
|
const mockOptions: ApiHandlerOptions = {
|
||||||
openRouterApiKey: 'test-key',
|
openRouterApiKey: "test-key",
|
||||||
openRouterModelId: 'test-model',
|
openRouterModelId: "test-model",
|
||||||
openRouterModelInfo: {
|
openRouterModelInfo: {
|
||||||
name: 'Test Model',
|
name: "Test Model",
|
||||||
description: 'Test Description',
|
description: "Test Description",
|
||||||
maxTokens: 1000,
|
maxTokens: 1000,
|
||||||
contextWindow: 2000,
|
contextWindow: 2000,
|
||||||
supportsPromptCache: true,
|
supportsPromptCache: true,
|
||||||
inputPrice: 0.01,
|
inputPrice: 0.01,
|
||||||
outputPrice: 0.02
|
outputPrice: 0.02,
|
||||||
} as ModelInfo
|
} as ModelInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
test('constructor initializes with correct options', () => {
|
test("constructor initializes with correct options", () => {
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
expect(handler).toBeInstanceOf(OpenRouterHandler)
|
expect(handler).toBeInstanceOf(OpenRouterHandler)
|
||||||
expect(OpenAI).toHaveBeenCalledWith({
|
expect(OpenAI).toHaveBeenCalledWith({
|
||||||
baseURL: 'https://openrouter.ai/api/v1',
|
baseURL: "https://openrouter.ai/api/v1",
|
||||||
apiKey: mockOptions.openRouterApiKey,
|
apiKey: mockOptions.openRouterApiKey,
|
||||||
defaultHeaders: {
|
defaultHeaders: {
|
||||||
'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline',
|
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
|
||||||
'X-Title': 'Roo-Cline',
|
"X-Title": "Roo-Cline",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
test('getModel returns correct model info when options are provided', () => {
|
test("getModel returns correct model info when options are provided", () => {
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
const result = handler.getModel()
|
const result = handler.getModel()
|
||||||
|
|
||||||
expect(result).toEqual({
|
|
||||||
id: mockOptions.openRouterModelId,
|
|
||||||
info: mockOptions.openRouterModelInfo
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test('getModel returns default model info when options are not provided', () => {
|
expect(result).toEqual({
|
||||||
const handler = new OpenRouterHandler({})
|
id: mockOptions.openRouterModelId,
|
||||||
const result = handler.getModel()
|
info: mockOptions.openRouterModelInfo,
|
||||||
|
})
|
||||||
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
|
})
|
||||||
expect(result.info.supportsPromptCache).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
test('createMessage generates correct stream chunks', async () => {
|
test("getModel returns default model info when options are not provided", () => {
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
const handler = new OpenRouterHandler({})
|
||||||
const mockStream = {
|
const result = handler.getModel()
|
||||||
async *[Symbol.asyncIterator]() {
|
|
||||||
yield {
|
|
||||||
id: 'test-id',
|
|
||||||
choices: [{
|
|
||||||
delta: {
|
|
||||||
content: 'test response'
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mock OpenAI chat.completions.create
|
expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta")
|
||||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
expect(result.info.supportsPromptCache).toBe(true)
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
})
|
||||||
completions: { create: mockCreate }
|
|
||||||
} as any
|
|
||||||
|
|
||||||
// Mock axios.get for generation details
|
test("createMessage generates correct stream chunks", async () => {
|
||||||
;(axios.get as jest.Mock).mockResolvedValue({
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
data: {
|
const mockStream = {
|
||||||
data: {
|
async *[Symbol.asyncIterator]() {
|
||||||
native_tokens_prompt: 10,
|
yield {
|
||||||
native_tokens_completion: 20,
|
id: "test-id",
|
||||||
total_cost: 0.001
|
choices: [
|
||||||
}
|
{
|
||||||
}
|
delta: {
|
||||||
})
|
content: "test response",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
const systemPrompt = 'test system prompt'
|
// Mock OpenAI chat.completions.create
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }]
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate },
|
||||||
|
} as any
|
||||||
|
|
||||||
const generator = handler.createMessage(systemPrompt, messages)
|
// Mock axios.get for generation details
|
||||||
const chunks = []
|
;(axios.get as jest.Mock).mockResolvedValue({
|
||||||
|
data: {
|
||||||
for await (const chunk of generator) {
|
data: {
|
||||||
chunks.push(chunk)
|
native_tokens_prompt: 10,
|
||||||
}
|
native_tokens_completion: 20,
|
||||||
|
total_cost: 0.001,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
// Verify stream chunks
|
const systemPrompt = "test system prompt"
|
||||||
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
|
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
|
||||||
expect(chunks[0]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: 'test response'
|
|
||||||
})
|
|
||||||
expect(chunks[1]).toEqual({
|
|
||||||
type: 'usage',
|
|
||||||
inputTokens: 10,
|
|
||||||
outputTokens: 20,
|
|
||||||
totalCost: 0.001,
|
|
||||||
fullResponseText: 'test response'
|
|
||||||
})
|
|
||||||
|
|
||||||
// Verify OpenAI client was called with correct parameters
|
const generator = handler.createMessage(systemPrompt, messages)
|
||||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
const chunks = []
|
||||||
model: mockOptions.openRouterModelId,
|
|
||||||
temperature: 0,
|
|
||||||
messages: expect.arrayContaining([
|
|
||||||
{ role: 'system', content: systemPrompt },
|
|
||||||
{ role: 'user', content: 'test message' }
|
|
||||||
]),
|
|
||||||
stream: true
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
test('createMessage with middle-out transform enabled', async () => {
|
for await (const chunk of generator) {
|
||||||
const handler = new OpenRouterHandler({
|
chunks.push(chunk)
|
||||||
...mockOptions,
|
}
|
||||||
openRouterUseMiddleOutTransform: true
|
|
||||||
})
|
|
||||||
const mockStream = {
|
|
||||||
async *[Symbol.asyncIterator]() {
|
|
||||||
yield {
|
|
||||||
id: 'test-id',
|
|
||||||
choices: [{
|
|
||||||
delta: {
|
|
||||||
content: 'test response'
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
// Verify stream chunks
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
|
||||||
completions: { create: mockCreate }
|
expect(chunks[0]).toEqual({
|
||||||
} as any
|
type: "text",
|
||||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
text: "test response",
|
||||||
|
})
|
||||||
|
expect(chunks[1]).toEqual({
|
||||||
|
type: "usage",
|
||||||
|
inputTokens: 10,
|
||||||
|
outputTokens: 20,
|
||||||
|
totalCost: 0.001,
|
||||||
|
fullResponseText: "test response",
|
||||||
|
})
|
||||||
|
|
||||||
await handler.createMessage('test', []).next()
|
// Verify OpenAI client was called with correct parameters
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
model: mockOptions.openRouterModelId,
|
||||||
|
temperature: 0,
|
||||||
|
messages: expect.arrayContaining([
|
||||||
|
{ role: "system", content: systemPrompt },
|
||||||
|
{ role: "user", content: "test message" },
|
||||||
|
]),
|
||||||
|
stream: true,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
test("createMessage with middle-out transform enabled", async () => {
|
||||||
transforms: ['middle-out']
|
const handler = new OpenRouterHandler({
|
||||||
}))
|
...mockOptions,
|
||||||
})
|
openRouterUseMiddleOutTransform: true,
|
||||||
|
})
|
||||||
|
const mockStream = {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
id: "test-id",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
content: "test response",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
test('createMessage with Claude model adds cache control', async () => {
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
const handler = new OpenRouterHandler({
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
...mockOptions,
|
completions: { create: mockCreate },
|
||||||
openRouterModelId: 'anthropic/claude-3.5-sonnet'
|
} as any
|
||||||
})
|
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||||
const mockStream = {
|
|
||||||
async *[Symbol.asyncIterator]() {
|
|
||||||
yield {
|
|
||||||
id: 'test-id',
|
|
||||||
choices: [{
|
|
||||||
delta: {
|
|
||||||
content: 'test response'
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
await handler.createMessage("test", []).next()
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
|
||||||
completions: { create: mockCreate }
|
|
||||||
} as any
|
|
||||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
|
||||||
|
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
expect(mockCreate).toHaveBeenCalledWith(
|
||||||
{ role: 'user', content: 'message 1' },
|
expect.objectContaining({
|
||||||
{ role: 'assistant', content: 'response 1' },
|
transforms: ["middle-out"],
|
||||||
{ role: 'user', content: 'message 2' }
|
}),
|
||||||
]
|
)
|
||||||
|
})
|
||||||
|
|
||||||
await handler.createMessage('test system', messages).next()
|
test("createMessage with Claude model adds cache control", async () => {
|
||||||
|
const handler = new OpenRouterHandler({
|
||||||
|
...mockOptions,
|
||||||
|
openRouterModelId: "anthropic/claude-3.5-sonnet",
|
||||||
|
})
|
||||||
|
const mockStream = {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
id: "test-id",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
content: "test response",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
messages: expect.arrayContaining([
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
expect.objectContaining({
|
completions: { create: mockCreate },
|
||||||
role: 'system',
|
} as any
|
||||||
content: expect.arrayContaining([
|
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||||
expect.objectContaining({
|
|
||||||
cache_control: { type: 'ephemeral' }
|
|
||||||
})
|
|
||||||
])
|
|
||||||
})
|
|
||||||
])
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
test('createMessage handles API errors', async () => {
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
{ role: "user", content: "message 1" },
|
||||||
const mockStream = {
|
{ role: "assistant", content: "response 1" },
|
||||||
async *[Symbol.asyncIterator]() {
|
{ role: "user", content: "message 2" },
|
||||||
yield {
|
]
|
||||||
error: {
|
|
||||||
message: 'API Error',
|
|
||||||
code: 500
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
await handler.createMessage("test system", messages).next()
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
|
||||||
completions: { create: mockCreate }
|
|
||||||
} as any
|
|
||||||
|
|
||||||
const generator = handler.createMessage('test', [])
|
expect(mockCreate).toHaveBeenCalledWith(
|
||||||
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
|
expect.objectContaining({
|
||||||
})
|
messages: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
role: "system",
|
||||||
|
content: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
cache_control: { type: "ephemeral" },
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
test('completePrompt returns correct response', async () => {
|
test("createMessage handles API errors", async () => {
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
const mockResponse = {
|
const mockStream = {
|
||||||
choices: [{
|
async *[Symbol.asyncIterator]() {
|
||||||
message: {
|
yield {
|
||||||
content: 'test completion'
|
error: {
|
||||||
}
|
message: "API Error",
|
||||||
}]
|
code: 500,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
completions: { create: mockCreate }
|
completions: { create: mockCreate },
|
||||||
} as any
|
} as any
|
||||||
|
|
||||||
const result = await handler.completePrompt('test prompt')
|
const generator = handler.createMessage("test", [])
|
||||||
|
await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
|
||||||
|
})
|
||||||
|
|
||||||
expect(result).toBe('test completion')
|
test("completePrompt returns correct response", async () => {
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
model: mockOptions.openRouterModelId,
|
const mockResponse = {
|
||||||
messages: [{ role: 'user', content: 'test prompt' }],
|
choices: [
|
||||||
temperature: 0,
|
{
|
||||||
stream: false
|
message: {
|
||||||
})
|
content: "test completion",
|
||||||
})
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
test('completePrompt handles API errors', async () => {
|
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
const mockError = {
|
completions: { create: mockCreate },
|
||||||
error: {
|
} as any
|
||||||
message: 'API Error',
|
|
||||||
code: 500
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
const result = await handler.completePrompt("test prompt")
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
|
||||||
completions: { create: mockCreate }
|
|
||||||
} as any
|
|
||||||
|
|
||||||
await expect(handler.completePrompt('test prompt'))
|
expect(result).toBe("test completion")
|
||||||
.rejects.toThrow('OpenRouter API Error 500: API Error')
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
})
|
model: mockOptions.openRouterModelId,
|
||||||
|
messages: [{ role: "user", content: "test prompt" }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
test('completePrompt handles unexpected errors', async () => {
|
test("completePrompt handles API errors", async () => {
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
|
const mockError = {
|
||||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
error: {
|
||||||
completions: { create: mockCreate }
|
message: "API Error",
|
||||||
} as any
|
code: 500,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
await expect(handler.completePrompt('test prompt'))
|
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
||||||
.rejects.toThrow('OpenRouter completion error: Unexpected error')
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
})
|
completions: { create: mockCreate },
|
||||||
|
} as any
|
||||||
|
|
||||||
|
await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error")
|
||||||
|
})
|
||||||
|
|
||||||
|
test("completePrompt handles unexpected errors", async () => {
|
||||||
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
|
const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error"))
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate },
|
||||||
|
} as any
|
||||||
|
|
||||||
|
await expect(handler.completePrompt("test prompt")).rejects.toThrow(
|
||||||
|
"OpenRouter completion error: Unexpected error",
|
||||||
|
)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,296 +1,295 @@
|
|||||||
import { VertexHandler } from '../vertex';
|
import { VertexHandler } from "../vertex"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
|
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
|
||||||
|
|
||||||
// Mock Vertex SDK
|
// Mock Vertex SDK
|
||||||
jest.mock('@anthropic-ai/vertex-sdk', () => ({
|
jest.mock("@anthropic-ai/vertex-sdk", () => ({
|
||||||
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
||||||
messages: {
|
messages: {
|
||||||
create: jest.fn().mockImplementation(async (options) => {
|
create: jest.fn().mockImplementation(async (options) => {
|
||||||
if (!options.stream) {
|
if (!options.stream) {
|
||||||
return {
|
return {
|
||||||
id: 'test-completion',
|
id: "test-completion",
|
||||||
content: [
|
content: [{ type: "text", text: "Test response" }],
|
||||||
{ type: 'text', text: 'Test response' }
|
role: "assistant",
|
||||||
],
|
model: options.model,
|
||||||
role: 'assistant',
|
usage: {
|
||||||
model: options.model,
|
input_tokens: 10,
|
||||||
usage: {
|
output_tokens: 5,
|
||||||
input_tokens: 10,
|
},
|
||||||
output_tokens: 5
|
}
|
||||||
}
|
}
|
||||||
}
|
return {
|
||||||
}
|
async *[Symbol.asyncIterator]() {
|
||||||
return {
|
yield {
|
||||||
async *[Symbol.asyncIterator]() {
|
type: "message_start",
|
||||||
yield {
|
message: {
|
||||||
type: 'message_start',
|
usage: {
|
||||||
message: {
|
input_tokens: 10,
|
||||||
usage: {
|
output_tokens: 5,
|
||||||
input_tokens: 10,
|
},
|
||||||
output_tokens: 5
|
},
|
||||||
}
|
}
|
||||||
}
|
yield {
|
||||||
}
|
type: "content_block_start",
|
||||||
yield {
|
content_block: {
|
||||||
type: 'content_block_start',
|
type: "text",
|
||||||
content_block: {
|
text: "Test response",
|
||||||
type: 'text',
|
},
|
||||||
text: 'Test response'
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}),
|
||||||
}
|
},
|
||||||
})
|
})),
|
||||||
}
|
}))
|
||||||
}))
|
|
||||||
}));
|
|
||||||
|
|
||||||
describe('VertexHandler', () => {
|
describe("VertexHandler", () => {
|
||||||
let handler: VertexHandler;
|
let handler: VertexHandler
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
handler = new VertexHandler({
|
handler = new VertexHandler({
|
||||||
apiModelId: 'claude-3-5-sonnet-v2@20241022',
|
apiModelId: "claude-3-5-sonnet-v2@20241022",
|
||||||
vertexProjectId: 'test-project',
|
vertexProjectId: "test-project",
|
||||||
vertexRegion: 'us-central1'
|
vertexRegion: "us-central1",
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe("constructor", () => {
|
||||||
it('should initialize with provided config', () => {
|
it("should initialize with provided config", () => {
|
||||||
expect(AnthropicVertex).toHaveBeenCalledWith({
|
expect(AnthropicVertex).toHaveBeenCalledWith({
|
||||||
projectId: 'test-project',
|
projectId: "test-project",
|
||||||
region: 'us-central1'
|
region: "us-central1",
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createMessage', () => {
|
describe("createMessage", () => {
|
||||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: 'Hello'
|
content: "Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: 'Hi there!'
|
content: "Hi there!",
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
const systemPrompt = "You are a helpful assistant"
|
||||||
|
|
||||||
it('should handle streaming responses correctly', async () => {
|
it("should handle streaming responses correctly", async () => {
|
||||||
const mockStream = [
|
const mockStream = [
|
||||||
{
|
{
|
||||||
type: 'message_start',
|
type: "message_start",
|
||||||
message: {
|
message: {
|
||||||
usage: {
|
usage: {
|
||||||
input_tokens: 10,
|
input_tokens: 10,
|
||||||
output_tokens: 0
|
output_tokens: 0,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'content_block_start',
|
type: "content_block_start",
|
||||||
index: 0,
|
index: 0,
|
||||||
content_block: {
|
content_block: {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Hello'
|
text: "Hello",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'content_block_delta',
|
type: "content_block_delta",
|
||||||
delta: {
|
delta: {
|
||||||
type: 'text_delta',
|
type: "text_delta",
|
||||||
text: ' world!'
|
text: " world!",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'message_delta',
|
type: "message_delta",
|
||||||
usage: {
|
usage: {
|
||||||
output_tokens: 5
|
output_tokens: 5,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
// Setup async iterator for mock stream
|
// Setup async iterator for mock stream
|
||||||
const asyncIterator = {
|
const asyncIterator = {
|
||||||
async *[Symbol.asyncIterator]() {
|
async *[Symbol.asyncIterator]() {
|
||||||
for (const chunk of mockStream) {
|
for (const chunk of mockStream) {
|
||||||
yield chunk;
|
yield chunk
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
||||||
(handler['client'].messages as any).create = mockCreate;
|
;(handler["client"].messages as any).create = mockCreate
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
const chunks = [];
|
const chunks = []
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(chunks.length).toBe(4);
|
for await (const chunk of stream) {
|
||||||
expect(chunks[0]).toEqual({
|
chunks.push(chunk)
|
||||||
type: 'usage',
|
}
|
||||||
inputTokens: 10,
|
|
||||||
outputTokens: 0
|
|
||||||
});
|
|
||||||
expect(chunks[1]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: 'Hello'
|
|
||||||
});
|
|
||||||
expect(chunks[2]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: ' world!'
|
|
||||||
});
|
|
||||||
expect(chunks[3]).toEqual({
|
|
||||||
type: 'usage',
|
|
||||||
inputTokens: 0,
|
|
||||||
outputTokens: 5
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockCreate).toHaveBeenCalledWith({
|
expect(chunks.length).toBe(4)
|
||||||
model: 'claude-3-5-sonnet-v2@20241022',
|
expect(chunks[0]).toEqual({
|
||||||
max_tokens: 8192,
|
type: "usage",
|
||||||
temperature: 0,
|
inputTokens: 10,
|
||||||
system: systemPrompt,
|
outputTokens: 0,
|
||||||
messages: mockMessages,
|
})
|
||||||
stream: true
|
expect(chunks[1]).toEqual({
|
||||||
});
|
type: "text",
|
||||||
});
|
text: "Hello",
|
||||||
|
})
|
||||||
|
expect(chunks[2]).toEqual({
|
||||||
|
type: "text",
|
||||||
|
text: " world!",
|
||||||
|
})
|
||||||
|
expect(chunks[3]).toEqual({
|
||||||
|
type: "usage",
|
||||||
|
inputTokens: 0,
|
||||||
|
outputTokens: 5,
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle multiple content blocks with line breaks', async () => {
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
const mockStream = [
|
model: "claude-3-5-sonnet-v2@20241022",
|
||||||
{
|
max_tokens: 8192,
|
||||||
type: 'content_block_start',
|
temperature: 0,
|
||||||
index: 0,
|
system: systemPrompt,
|
||||||
content_block: {
|
messages: mockMessages,
|
||||||
type: 'text',
|
stream: true,
|
||||||
text: 'First line'
|
})
|
||||||
}
|
})
|
||||||
},
|
|
||||||
{
|
|
||||||
type: 'content_block_start',
|
|
||||||
index: 1,
|
|
||||||
content_block: {
|
|
||||||
type: 'text',
|
|
||||||
text: 'Second line'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
];
|
|
||||||
|
|
||||||
const asyncIterator = {
|
it("should handle multiple content blocks with line breaks", async () => {
|
||||||
async *[Symbol.asyncIterator]() {
|
const mockStream = [
|
||||||
for (const chunk of mockStream) {
|
{
|
||||||
yield chunk;
|
type: "content_block_start",
|
||||||
}
|
index: 0,
|
||||||
}
|
content_block: {
|
||||||
};
|
type: "text",
|
||||||
|
text: "First line",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "content_block_start",
|
||||||
|
index: 1,
|
||||||
|
content_block: {
|
||||||
|
type: "text",
|
||||||
|
text: "Second line",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
const asyncIterator = {
|
||||||
(handler['client'].messages as any).create = mockCreate;
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of mockStream) {
|
||||||
|
yield chunk
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
||||||
const chunks = [];
|
;(handler["client"].messages as any).create = mockCreate
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
|
||||||
chunks.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(chunks.length).toBe(3);
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
expect(chunks[0]).toEqual({
|
const chunks = []
|
||||||
type: 'text',
|
|
||||||
text: 'First line'
|
|
||||||
});
|
|
||||||
expect(chunks[1]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: '\n'
|
|
||||||
});
|
|
||||||
expect(chunks[2]).toEqual({
|
|
||||||
type: 'text',
|
|
||||||
text: 'Second line'
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
for await (const chunk of stream) {
|
||||||
const mockError = new Error('Vertex API error');
|
chunks.push(chunk)
|
||||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
}
|
||||||
(handler['client'].messages as any).create = mockCreate;
|
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
expect(chunks.length).toBe(3)
|
||||||
|
expect(chunks[0]).toEqual({
|
||||||
|
type: "text",
|
||||||
|
text: "First line",
|
||||||
|
})
|
||||||
|
expect(chunks[1]).toEqual({
|
||||||
|
type: "text",
|
||||||
|
text: "\n",
|
||||||
|
})
|
||||||
|
expect(chunks[2]).toEqual({
|
||||||
|
type: "text",
|
||||||
|
text: "Second line",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
await expect(async () => {
|
it("should handle API errors", async () => {
|
||||||
for await (const chunk of stream) {
|
const mockError = new Error("Vertex API error")
|
||||||
// Should throw before yielding any chunks
|
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
||||||
}
|
;(handler["client"].messages as any).create = mockCreate
|
||||||
}).rejects.toThrow('Vertex API error');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||||
it('should complete prompt successfully', async () => {
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
|
||||||
expect(result).toBe('Test response');
|
|
||||||
expect(handler['client'].messages.create).toHaveBeenCalledWith({
|
|
||||||
model: 'claude-3-5-sonnet-v2@20241022',
|
|
||||||
max_tokens: 8192,
|
|
||||||
temperature: 0,
|
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
|
||||||
stream: false
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle API errors', async () => {
|
await expect(async () => {
|
||||||
const mockError = new Error('Vertex API error');
|
for await (const chunk of stream) {
|
||||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
// Should throw before yielding any chunks
|
||||||
(handler['client'].messages as any).create = mockCreate;
|
}
|
||||||
|
}).rejects.toThrow("Vertex API error")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
describe("completePrompt", () => {
|
||||||
.rejects.toThrow('Vertex completion error: Vertex API error');
|
it("should complete prompt successfully", async () => {
|
||||||
});
|
const result = await handler.completePrompt("Test prompt")
|
||||||
|
expect(result).toBe("Test response")
|
||||||
|
expect(handler["client"].messages.create).toHaveBeenCalledWith({
|
||||||
|
model: "claude-3-5-sonnet-v2@20241022",
|
||||||
|
max_tokens: 8192,
|
||||||
|
temperature: 0,
|
||||||
|
messages: [{ role: "user", content: "Test prompt" }],
|
||||||
|
stream: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle non-text content', async () => {
|
it("should handle API errors", async () => {
|
||||||
const mockCreate = jest.fn().mockResolvedValue({
|
const mockError = new Error("Vertex API error")
|
||||||
content: [{ type: 'image' }]
|
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
||||||
});
|
;(handler["client"].messages as any).create = mockCreate
|
||||||
(handler['client'].messages as any).create = mockCreate;
|
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||||
expect(result).toBe('');
|
"Vertex completion error: Vertex API error",
|
||||||
});
|
)
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle empty response', async () => {
|
it("should handle non-text content", async () => {
|
||||||
const mockCreate = jest.fn().mockResolvedValue({
|
const mockCreate = jest.fn().mockResolvedValue({
|
||||||
content: [{ type: 'text', text: '' }]
|
content: [{ type: "image" }],
|
||||||
});
|
})
|
||||||
(handler['client'].messages as any).create = mockCreate;
|
;(handler["client"].messages as any).create = mockCreate
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe('');
|
expect(result).toBe("")
|
||||||
});
|
})
|
||||||
});
|
|
||||||
|
|
||||||
describe('getModel', () => {
|
it("should handle empty response", async () => {
|
||||||
it('should return correct model info', () => {
|
const mockCreate = jest.fn().mockResolvedValue({
|
||||||
const modelInfo = handler.getModel();
|
content: [{ type: "text", text: "" }],
|
||||||
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022');
|
})
|
||||||
expect(modelInfo.info).toBeDefined();
|
;(handler["client"].messages as any).create = mockCreate
|
||||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
|
||||||
expect(modelInfo.info.contextWindow).toBe(200_000);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should return default model if invalid model specified', () => {
|
const result = await handler.completePrompt("Test prompt")
|
||||||
const invalidHandler = new VertexHandler({
|
expect(result).toBe("")
|
||||||
apiModelId: 'invalid-model',
|
})
|
||||||
vertexProjectId: 'test-project',
|
})
|
||||||
vertexRegion: 'us-central1'
|
|
||||||
});
|
describe("getModel", () => {
|
||||||
const modelInfo = invalidHandler.getModel();
|
it("should return correct model info", () => {
|
||||||
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model
|
const modelInfo = handler.getModel()
|
||||||
});
|
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
|
||||||
});
|
expect(modelInfo.info).toBeDefined()
|
||||||
});
|
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||||
|
expect(modelInfo.info.contextWindow).toBe(200_000)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should return default model if invalid model specified", () => {
|
||||||
|
const invalidHandler = new VertexHandler({
|
||||||
|
apiModelId: "invalid-model",
|
||||||
|
vertexProjectId: "test-project",
|
||||||
|
vertexRegion: "us-central1",
|
||||||
|
})
|
||||||
|
const modelInfo = invalidHandler.getModel()
|
||||||
|
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,289 +1,295 @@
|
|||||||
import * as vscode from 'vscode';
|
import * as vscode from "vscode"
|
||||||
import { VsCodeLmHandler } from '../vscode-lm';
|
import { VsCodeLmHandler } from "../vscode-lm"
|
||||||
import { ApiHandlerOptions } from '../../../shared/api';
|
import { ApiHandlerOptions } from "../../../shared/api"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
|
|
||||||
// Mock vscode namespace
|
// Mock vscode namespace
|
||||||
jest.mock('vscode', () => {
|
jest.mock("vscode", () => {
|
||||||
class MockLanguageModelTextPart {
|
class MockLanguageModelTextPart {
|
||||||
type = 'text';
|
type = "text"
|
||||||
constructor(public value: string) {}
|
constructor(public value: string) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
class MockLanguageModelToolCallPart {
|
class MockLanguageModelToolCallPart {
|
||||||
type = 'tool_call';
|
type = "tool_call"
|
||||||
constructor(
|
constructor(
|
||||||
public callId: string,
|
public callId: string,
|
||||||
public name: string,
|
public name: string,
|
||||||
public input: any
|
public input: any,
|
||||||
) {}
|
) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
workspace: {
|
workspace: {
|
||||||
onDidChangeConfiguration: jest.fn((callback) => ({
|
onDidChangeConfiguration: jest.fn((callback) => ({
|
||||||
dispose: jest.fn()
|
dispose: jest.fn(),
|
||||||
}))
|
})),
|
||||||
},
|
},
|
||||||
CancellationTokenSource: jest.fn(() => ({
|
CancellationTokenSource: jest.fn(() => ({
|
||||||
token: {
|
token: {
|
||||||
isCancellationRequested: false,
|
isCancellationRequested: false,
|
||||||
onCancellationRequested: jest.fn()
|
onCancellationRequested: jest.fn(),
|
||||||
},
|
},
|
||||||
cancel: jest.fn(),
|
cancel: jest.fn(),
|
||||||
dispose: jest.fn()
|
dispose: jest.fn(),
|
||||||
})),
|
})),
|
||||||
CancellationError: class CancellationError extends Error {
|
CancellationError: class CancellationError extends Error {
|
||||||
constructor() {
|
constructor() {
|
||||||
super('Operation cancelled');
|
super("Operation cancelled")
|
||||||
this.name = 'CancellationError';
|
this.name = "CancellationError"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
LanguageModelChatMessage: {
|
LanguageModelChatMessage: {
|
||||||
Assistant: jest.fn((content) => ({
|
Assistant: jest.fn((content) => ({
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||||
})),
|
})),
|
||||||
User: jest.fn((content) => ({
|
User: jest.fn((content) => ({
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||||
}))
|
})),
|
||||||
},
|
},
|
||||||
LanguageModelTextPart: MockLanguageModelTextPart,
|
LanguageModelTextPart: MockLanguageModelTextPart,
|
||||||
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
||||||
lm: {
|
lm: {
|
||||||
selectChatModels: jest.fn()
|
selectChatModels: jest.fn(),
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
const mockLanguageModelChat = {
|
const mockLanguageModelChat = {
|
||||||
id: 'test-model',
|
id: "test-model",
|
||||||
name: 'Test Model',
|
name: "Test Model",
|
||||||
vendor: 'test-vendor',
|
vendor: "test-vendor",
|
||||||
family: 'test-family',
|
family: "test-family",
|
||||||
version: '1.0',
|
version: "1.0",
|
||||||
maxInputTokens: 4096,
|
maxInputTokens: 4096,
|
||||||
sendRequest: jest.fn(),
|
sendRequest: jest.fn(),
|
||||||
countTokens: jest.fn()
|
countTokens: jest.fn(),
|
||||||
};
|
}
|
||||||
|
|
||||||
describe('VsCodeLmHandler', () => {
|
describe("VsCodeLmHandler", () => {
|
||||||
let handler: VsCodeLmHandler;
|
let handler: VsCodeLmHandler
|
||||||
const defaultOptions: ApiHandlerOptions = {
|
const defaultOptions: ApiHandlerOptions = {
|
||||||
vsCodeLmModelSelector: {
|
vsCodeLmModelSelector: {
|
||||||
vendor: 'test-vendor',
|
vendor: "test-vendor",
|
||||||
family: 'test-family'
|
family: "test-family",
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks()
|
||||||
handler = new VsCodeLmHandler(defaultOptions);
|
handler = new VsCodeLmHandler(defaultOptions)
|
||||||
});
|
})
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
handler.dispose();
|
handler.dispose()
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
describe("constructor", () => {
|
||||||
it('should initialize with provided options', () => {
|
it("should initialize with provided options", () => {
|
||||||
expect(handler).toBeDefined();
|
expect(handler).toBeDefined()
|
||||||
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled();
|
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled()
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle configuration changes', () => {
|
it("should handle configuration changes", () => {
|
||||||
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0];
|
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]
|
||||||
callback({ affectsConfiguration: () => true });
|
callback({ affectsConfiguration: () => true })
|
||||||
// Should reset client when config changes
|
// Should reset client when config changes
|
||||||
expect(handler['client']).toBeNull();
|
expect(handler["client"]).toBeNull()
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('createClient', () => {
|
describe("createClient", () => {
|
||||||
it('should create client with selector', async () => {
|
it("should create client with selector", async () => {
|
||||||
const mockModel = { ...mockLanguageModelChat };
|
const mockModel = { ...mockLanguageModelChat }
|
||||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||||
|
|
||||||
const client = await handler['createClient']({
|
const client = await handler["createClient"]({
|
||||||
vendor: 'test-vendor',
|
vendor: "test-vendor",
|
||||||
family: 'test-family'
|
family: "test-family",
|
||||||
});
|
})
|
||||||
|
|
||||||
expect(client).toBeDefined();
|
expect(client).toBeDefined()
|
||||||
expect(client.id).toBe('test-model');
|
expect(client.id).toBe("test-model")
|
||||||
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
|
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
|
||||||
vendor: 'test-vendor',
|
vendor: "test-vendor",
|
||||||
family: 'test-family'
|
family: "test-family",
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should return default client when no models available', async () => {
|
it("should return default client when no models available", async () => {
|
||||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]);
|
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([])
|
||||||
|
|
||||||
const client = await handler['createClient']({});
|
const client = await handler["createClient"]({})
|
||||||
|
|
||||||
expect(client).toBeDefined();
|
|
||||||
expect(client.id).toBe('default-lm');
|
|
||||||
expect(client.vendor).toBe('vscode');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('createMessage', () => {
|
expect(client).toBeDefined()
|
||||||
|
expect(client.id).toBe("default-lm")
|
||||||
|
expect(client.vendor).toBe("vscode")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("createMessage", () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
const mockModel = { ...mockLanguageModelChat };
|
const mockModel = { ...mockLanguageModelChat }
|
||||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||||
mockLanguageModelChat.countTokens.mockResolvedValue(10);
|
mockLanguageModelChat.countTokens.mockResolvedValue(10)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should stream text responses', async () => {
|
it("should stream text responses", async () => {
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
const systemPrompt = "You are a helpful assistant"
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
role: 'user' as const,
|
{
|
||||||
content: 'Hello'
|
role: "user" as const,
|
||||||
}];
|
content: "Hello",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
const responseText = 'Hello! How can I help you?';
|
const responseText = "Hello! How can I help you?"
|
||||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||||
stream: (async function* () {
|
stream: (async function* () {
|
||||||
yield new vscode.LanguageModelTextPart(responseText);
|
yield new vscode.LanguageModelTextPart(responseText)
|
||||||
return;
|
return
|
||||||
})(),
|
})(),
|
||||||
text: (async function* () {
|
text: (async function* () {
|
||||||
yield responseText;
|
yield responseText
|
||||||
return;
|
return
|
||||||
})()
|
})(),
|
||||||
});
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
const chunks = [];
|
const chunks = []
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks).toHaveLength(2); // Text chunk + usage chunk
|
expect(chunks).toHaveLength(2) // Text chunk + usage chunk
|
||||||
expect(chunks[0]).toEqual({
|
expect(chunks[0]).toEqual({
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: responseText
|
text: responseText,
|
||||||
});
|
})
|
||||||
expect(chunks[1]).toMatchObject({
|
expect(chunks[1]).toMatchObject({
|
||||||
type: 'usage',
|
type: "usage",
|
||||||
inputTokens: expect.any(Number),
|
inputTokens: expect.any(Number),
|
||||||
outputTokens: expect.any(Number)
|
outputTokens: expect.any(Number),
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle tool calls', async () => {
|
it("should handle tool calls", async () => {
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
const systemPrompt = "You are a helpful assistant"
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
role: 'user' as const,
|
{
|
||||||
content: 'Calculate 2+2'
|
role: "user" as const,
|
||||||
}];
|
content: "Calculate 2+2",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
const toolCallData = {
|
const toolCallData = {
|
||||||
name: 'calculator',
|
name: "calculator",
|
||||||
arguments: { operation: 'add', numbers: [2, 2] },
|
arguments: { operation: "add", numbers: [2, 2] },
|
||||||
callId: 'call-1'
|
callId: "call-1",
|
||||||
};
|
}
|
||||||
|
|
||||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||||
stream: (async function* () {
|
stream: (async function* () {
|
||||||
yield new vscode.LanguageModelToolCallPart(
|
yield new vscode.LanguageModelToolCallPart(
|
||||||
toolCallData.callId,
|
toolCallData.callId,
|
||||||
toolCallData.name,
|
toolCallData.name,
|
||||||
toolCallData.arguments
|
toolCallData.arguments,
|
||||||
);
|
)
|
||||||
return;
|
return
|
||||||
})(),
|
})(),
|
||||||
text: (async function* () {
|
text: (async function* () {
|
||||||
yield JSON.stringify({ type: 'tool_call', ...toolCallData });
|
yield JSON.stringify({ type: "tool_call", ...toolCallData })
|
||||||
return;
|
return
|
||||||
})()
|
})(),
|
||||||
});
|
})
|
||||||
|
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
const chunks = [];
|
const chunks = []
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
chunks.push(chunk);
|
chunks.push(chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk
|
expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
|
||||||
expect(chunks[0]).toEqual({
|
expect(chunks[0]).toEqual({
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: JSON.stringify({ type: 'tool_call', ...toolCallData })
|
text: JSON.stringify({ type: "tool_call", ...toolCallData }),
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle errors', async () => {
|
it("should handle errors", async () => {
|
||||||
const systemPrompt = 'You are a helpful assistant';
|
const systemPrompt = "You are a helpful assistant"
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
role: 'user' as const,
|
{
|
||||||
content: 'Hello'
|
role: "user" as const,
|
||||||
}];
|
content: "Hello",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error'));
|
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
|
||||||
|
|
||||||
await expect(async () => {
|
await expect(async () => {
|
||||||
const stream = handler.createMessage(systemPrompt, messages);
|
const stream = handler.createMessage(systemPrompt, messages)
|
||||||
for await (const _ of stream) {
|
for await (const _ of stream) {
|
||||||
// consume stream
|
// consume stream
|
||||||
}
|
}
|
||||||
}).rejects.toThrow('API Error');
|
}).rejects.toThrow("API Error")
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
|
describe("getModel", () => {
|
||||||
|
it("should return model info when client exists", async () => {
|
||||||
|
const mockModel = { ...mockLanguageModelChat }
|
||||||
|
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||||
|
|
||||||
describe('getModel', () => {
|
|
||||||
it('should return model info when client exists', async () => {
|
|
||||||
const mockModel = { ...mockLanguageModelChat };
|
|
||||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
|
||||||
|
|
||||||
// Initialize client
|
// Initialize client
|
||||||
await handler['getClient']();
|
await handler["getClient"]()
|
||||||
|
|
||||||
const model = handler.getModel();
|
|
||||||
expect(model.id).toBe('test-model');
|
|
||||||
expect(model.info).toBeDefined();
|
|
||||||
expect(model.info.contextWindow).toBe(4096);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should return fallback model info when no client exists', () => {
|
const model = handler.getModel()
|
||||||
const model = handler.getModel();
|
expect(model.id).toBe("test-model")
|
||||||
expect(model.id).toBe('test-vendor/test-family');
|
expect(model.info).toBeDefined()
|
||||||
expect(model.info).toBeDefined();
|
expect(model.info.contextWindow).toBe(4096)
|
||||||
});
|
})
|
||||||
});
|
|
||||||
|
|
||||||
describe('completePrompt', () => {
|
it("should return fallback model info when no client exists", () => {
|
||||||
it('should complete single prompt', async () => {
|
const model = handler.getModel()
|
||||||
const mockModel = { ...mockLanguageModelChat };
|
expect(model.id).toBe("test-vendor/test-family")
|
||||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
expect(model.info).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const responseText = 'Completed text';
|
describe("completePrompt", () => {
|
||||||
|
it("should complete single prompt", async () => {
|
||||||
|
const mockModel = { ...mockLanguageModelChat }
|
||||||
|
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||||
|
|
||||||
|
const responseText = "Completed text"
|
||||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||||
stream: (async function* () {
|
stream: (async function* () {
|
||||||
yield new vscode.LanguageModelTextPart(responseText);
|
yield new vscode.LanguageModelTextPart(responseText)
|
||||||
return;
|
return
|
||||||
})(),
|
})(),
|
||||||
text: (async function* () {
|
text: (async function* () {
|
||||||
yield responseText;
|
yield responseText
|
||||||
return;
|
return
|
||||||
})()
|
})(),
|
||||||
});
|
})
|
||||||
|
|
||||||
const result = await handler.completePrompt('Test prompt');
|
const result = await handler.completePrompt("Test prompt")
|
||||||
expect(result).toBe(responseText);
|
expect(result).toBe(responseText)
|
||||||
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled();
|
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle errors during completion', async () => {
|
it("should handle errors during completion", async () => {
|
||||||
const mockModel = { ...mockLanguageModelChat };
|
const mockModel = { ...mockLanguageModelChat }
|
||||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||||
|
|
||||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed'));
|
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))
|
||||||
|
|
||||||
await expect(handler.completePrompt('Test prompt'))
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||||
.rejects
|
"VSCode LM completion error: Completion failed",
|
||||||
.toThrow('VSCode LM completion error: Completion failed');
|
)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|||||||
@@ -181,14 +181,14 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
stream: false
|
stream: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
const content = response.content[0]
|
const content = response.content[0]
|
||||||
if (content.type === 'text') {
|
if (content.type === "text") {
|
||||||
return content.text
|
return content.text
|
||||||
}
|
}
|
||||||
return ''
|
return ""
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
throw new Error(`Anthropic completion error: ${error.message}`)
|
throw new Error(`Anthropic completion error: ${error.message}`)
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
import {
|
||||||
|
BedrockRuntimeClient,
|
||||||
|
ConverseStreamCommand,
|
||||||
|
ConverseCommand,
|
||||||
|
BedrockRuntimeClientConfig,
|
||||||
|
} from "@aws-sdk/client-bedrock-runtime"
|
||||||
import { Anthropic } from "@anthropic-ai/sdk"
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||||
@@ -7,275 +12,276 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../
|
|||||||
|
|
||||||
// Define types for stream events based on AWS SDK
|
// Define types for stream events based on AWS SDK
|
||||||
export interface StreamEvent {
|
export interface StreamEvent {
|
||||||
messageStart?: {
|
messageStart?: {
|
||||||
role?: string;
|
role?: string
|
||||||
};
|
}
|
||||||
messageStop?: {
|
messageStop?: {
|
||||||
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence";
|
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
|
||||||
additionalModelResponseFields?: Record<string, unknown>;
|
additionalModelResponseFields?: Record<string, unknown>
|
||||||
};
|
}
|
||||||
contentBlockStart?: {
|
contentBlockStart?: {
|
||||||
start?: {
|
start?: {
|
||||||
text?: string;
|
text?: string
|
||||||
};
|
}
|
||||||
contentBlockIndex?: number;
|
contentBlockIndex?: number
|
||||||
};
|
}
|
||||||
contentBlockDelta?: {
|
contentBlockDelta?: {
|
||||||
delta?: {
|
delta?: {
|
||||||
text?: string;
|
text?: string
|
||||||
};
|
}
|
||||||
contentBlockIndex?: number;
|
contentBlockIndex?: number
|
||||||
};
|
}
|
||||||
metadata?: {
|
metadata?: {
|
||||||
usage?: {
|
usage?: {
|
||||||
inputTokens: number;
|
inputTokens: number
|
||||||
outputTokens: number;
|
outputTokens: number
|
||||||
totalTokens?: number; // Made optional since we don't use it
|
totalTokens?: number // Made optional since we don't use it
|
||||||
};
|
}
|
||||||
metrics?: {
|
metrics?: {
|
||||||
latencyMs: number;
|
latencyMs: number
|
||||||
};
|
}
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: BedrockRuntimeClient
|
private client: BedrockRuntimeClient
|
||||||
|
|
||||||
constructor(options: ApiHandlerOptions) {
|
constructor(options: ApiHandlerOptions) {
|
||||||
this.options = options
|
this.options = options
|
||||||
|
|
||||||
// Only include credentials if they actually exist
|
|
||||||
const clientConfig: BedrockRuntimeClientConfig = {
|
|
||||||
region: this.options.awsRegion || "us-east-1"
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
// Only include credentials if they actually exist
|
||||||
// Create credentials object with all properties at once
|
const clientConfig: BedrockRuntimeClientConfig = {
|
||||||
clientConfig.credentials = {
|
region: this.options.awsRegion || "us-east-1",
|
||||||
accessKeyId: this.options.awsAccessKey,
|
}
|
||||||
secretAccessKey: this.options.awsSecretKey,
|
|
||||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.client = new BedrockRuntimeClient(clientConfig)
|
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||||
}
|
// Create credentials object with all properties at once
|
||||||
|
clientConfig.credentials = {
|
||||||
|
accessKeyId: this.options.awsAccessKey,
|
||||||
|
secretAccessKey: this.options.awsSecretKey,
|
||||||
|
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
this.client = new BedrockRuntimeClient(clientConfig)
|
||||||
const modelConfig = this.getModel()
|
}
|
||||||
|
|
||||||
// Handle cross-region inference
|
|
||||||
let modelId: string
|
|
||||||
if (this.options.awsUseCrossRegionInference) {
|
|
||||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
|
||||||
switch (regionPrefix) {
|
|
||||||
case "us-":
|
|
||||||
modelId = `us.${modelConfig.id}`
|
|
||||||
break
|
|
||||||
case "eu-":
|
|
||||||
modelId = `eu.${modelConfig.id}`
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
modelId = modelConfig.id
|
|
||||||
break
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
modelId = modelConfig.id
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert messages to Bedrock format
|
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||||
const formattedMessages = convertToBedrockConverseMessages(messages)
|
const modelConfig = this.getModel()
|
||||||
|
|
||||||
// Construct the payload
|
// Handle cross-region inference
|
||||||
const payload = {
|
let modelId: string
|
||||||
modelId,
|
if (this.options.awsUseCrossRegionInference) {
|
||||||
messages: formattedMessages,
|
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||||
system: [{ text: systemPrompt }],
|
switch (regionPrefix) {
|
||||||
inferenceConfig: {
|
case "us-":
|
||||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
modelId = `us.${modelConfig.id}`
|
||||||
temperature: 0.3,
|
break
|
||||||
topP: 0.1,
|
case "eu-":
|
||||||
...(this.options.awsUsePromptCache ? {
|
modelId = `eu.${modelConfig.id}`
|
||||||
promptCache: {
|
break
|
||||||
promptCacheId: this.options.awspromptCacheId || ""
|
default:
|
||||||
}
|
modelId = modelConfig.id
|
||||||
} : {})
|
break
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
modelId = modelConfig.id
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
// Convert messages to Bedrock format
|
||||||
const command = new ConverseStreamCommand(payload)
|
const formattedMessages = convertToBedrockConverseMessages(messages)
|
||||||
const response = await this.client.send(command)
|
|
||||||
|
|
||||||
if (!response.stream) {
|
// Construct the payload
|
||||||
throw new Error('No stream available in the response')
|
const payload = {
|
||||||
}
|
modelId,
|
||||||
|
messages: formattedMessages,
|
||||||
|
system: [{ text: systemPrompt }],
|
||||||
|
inferenceConfig: {
|
||||||
|
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||||
|
temperature: 0.3,
|
||||||
|
topP: 0.1,
|
||||||
|
...(this.options.awsUsePromptCache
|
||||||
|
? {
|
||||||
|
promptCache: {
|
||||||
|
promptCacheId: this.options.awspromptCacheId || "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
for await (const chunk of response.stream) {
|
try {
|
||||||
// Parse the chunk as JSON if it's a string (for tests)
|
const command = new ConverseStreamCommand(payload)
|
||||||
let streamEvent: StreamEvent
|
const response = await this.client.send(command)
|
||||||
try {
|
|
||||||
streamEvent = typeof chunk === 'string' ?
|
|
||||||
JSON.parse(chunk) :
|
|
||||||
chunk as unknown as StreamEvent
|
|
||||||
} catch (e) {
|
|
||||||
console.error('Failed to parse stream event:', e)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle metadata events first
|
if (!response.stream) {
|
||||||
if (streamEvent.metadata?.usage) {
|
throw new Error("No stream available in the response")
|
||||||
yield {
|
}
|
||||||
type: "usage",
|
|
||||||
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
|
||||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle message start
|
for await (const chunk of response.stream) {
|
||||||
if (streamEvent.messageStart) {
|
// Parse the chunk as JSON if it's a string (for tests)
|
||||||
continue
|
let streamEvent: StreamEvent
|
||||||
}
|
try {
|
||||||
|
streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Failed to parse stream event:", e)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Handle content blocks
|
// Handle metadata events first
|
||||||
if (streamEvent.contentBlockStart?.start?.text) {
|
if (streamEvent.metadata?.usage) {
|
||||||
yield {
|
yield {
|
||||||
type: "text",
|
type: "usage",
|
||||||
text: streamEvent.contentBlockStart.start.text
|
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||||
}
|
outputTokens: streamEvent.metadata.usage.outputTokens || 0,
|
||||||
continue
|
}
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Handle content deltas
|
// Handle message start
|
||||||
if (streamEvent.contentBlockDelta?.delta?.text) {
|
if (streamEvent.messageStart) {
|
||||||
yield {
|
continue
|
||||||
type: "text",
|
}
|
||||||
text: streamEvent.contentBlockDelta.delta.text
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle message stop
|
// Handle content blocks
|
||||||
if (streamEvent.messageStop) {
|
if (streamEvent.contentBlockStart?.start?.text) {
|
||||||
continue
|
yield {
|
||||||
}
|
type: "text",
|
||||||
}
|
text: streamEvent.contentBlockStart.start.text,
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
} catch (error: unknown) {
|
// Handle content deltas
|
||||||
console.error('Bedrock Runtime API Error:', error)
|
if (streamEvent.contentBlockDelta?.delta?.text) {
|
||||||
// Only access stack if error is an Error object
|
yield {
|
||||||
if (error instanceof Error) {
|
type: "text",
|
||||||
console.error('Error stack:', error.stack)
|
text: streamEvent.contentBlockDelta.delta.text,
|
||||||
yield {
|
}
|
||||||
type: "text",
|
continue
|
||||||
text: `Error: ${error.message}`
|
}
|
||||||
}
|
|
||||||
yield {
|
|
||||||
type: "usage",
|
|
||||||
inputTokens: 0,
|
|
||||||
outputTokens: 0
|
|
||||||
}
|
|
||||||
throw error
|
|
||||||
} else {
|
|
||||||
const unknownError = new Error("An unknown error occurred")
|
|
||||||
yield {
|
|
||||||
type: "text",
|
|
||||||
text: unknownError.message
|
|
||||||
}
|
|
||||||
yield {
|
|
||||||
type: "usage",
|
|
||||||
inputTokens: 0,
|
|
||||||
outputTokens: 0
|
|
||||||
}
|
|
||||||
throw unknownError
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
// Handle message stop
|
||||||
const modelId = this.options.apiModelId
|
if (streamEvent.messageStop) {
|
||||||
if (modelId) {
|
continue
|
||||||
// For tests, allow any model ID
|
}
|
||||||
if (process.env.NODE_ENV === 'test') {
|
}
|
||||||
return {
|
} catch (error: unknown) {
|
||||||
id: modelId,
|
console.error("Bedrock Runtime API Error:", error)
|
||||||
info: {
|
// Only access stack if error is an Error object
|
||||||
maxTokens: 5000,
|
if (error instanceof Error) {
|
||||||
contextWindow: 128_000,
|
console.error("Error stack:", error.stack)
|
||||||
supportsPromptCache: false
|
yield {
|
||||||
}
|
type: "text",
|
||||||
}
|
text: `Error: ${error.message}`,
|
||||||
}
|
}
|
||||||
// For production, validate against known models
|
yield {
|
||||||
if (modelId in bedrockModels) {
|
type: "usage",
|
||||||
const id = modelId as BedrockModelId
|
inputTokens: 0,
|
||||||
return { id, info: bedrockModels[id] }
|
outputTokens: 0,
|
||||||
}
|
}
|
||||||
}
|
throw error
|
||||||
return {
|
} else {
|
||||||
id: bedrockDefaultModelId,
|
const unknownError = new Error("An unknown error occurred")
|
||||||
info: bedrockModels[bedrockDefaultModelId]
|
yield {
|
||||||
}
|
type: "text",
|
||||||
}
|
text: unknownError.message,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
type: "usage",
|
||||||
|
inputTokens: 0,
|
||||||
|
outputTokens: 0,
|
||||||
|
}
|
||||||
|
throw unknownError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async completePrompt(prompt: string): Promise<string> {
|
getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
||||||
try {
|
const modelId = this.options.apiModelId
|
||||||
const modelConfig = this.getModel()
|
if (modelId) {
|
||||||
|
// For tests, allow any model ID
|
||||||
// Handle cross-region inference
|
if (process.env.NODE_ENV === "test") {
|
||||||
let modelId: string
|
return {
|
||||||
if (this.options.awsUseCrossRegionInference) {
|
id: modelId,
|
||||||
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
info: {
|
||||||
switch (regionPrefix) {
|
maxTokens: 5000,
|
||||||
case "us-":
|
contextWindow: 128_000,
|
||||||
modelId = `us.${modelConfig.id}`
|
supportsPromptCache: false,
|
||||||
break
|
},
|
||||||
case "eu-":
|
}
|
||||||
modelId = `eu.${modelConfig.id}`
|
}
|
||||||
break
|
// For production, validate against known models
|
||||||
default:
|
if (modelId in bedrockModels) {
|
||||||
modelId = modelConfig.id
|
const id = modelId as BedrockModelId
|
||||||
break
|
return { id, info: bedrockModels[id] }
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
modelId = modelConfig.id
|
return {
|
||||||
}
|
id: bedrockDefaultModelId,
|
||||||
|
info: bedrockModels[bedrockDefaultModelId],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const payload = {
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
modelId,
|
try {
|
||||||
messages: convertToBedrockConverseMessages([{
|
const modelConfig = this.getModel()
|
||||||
role: "user",
|
|
||||||
content: prompt
|
|
||||||
}]),
|
|
||||||
inferenceConfig: {
|
|
||||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
|
||||||
temperature: 0.3,
|
|
||||||
topP: 0.1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const command = new ConverseCommand(payload)
|
// Handle cross-region inference
|
||||||
const response = await this.client.send(command)
|
let modelId: string
|
||||||
|
if (this.options.awsUseCrossRegionInference) {
|
||||||
|
let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
||||||
|
switch (regionPrefix) {
|
||||||
|
case "us-":
|
||||||
|
modelId = `us.${modelConfig.id}`
|
||||||
|
break
|
||||||
|
case "eu-":
|
||||||
|
modelId = `eu.${modelConfig.id}`
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
modelId = modelConfig.id
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
modelId = modelConfig.id
|
||||||
|
}
|
||||||
|
|
||||||
if (response.output && response.output instanceof Uint8Array) {
|
const payload = {
|
||||||
try {
|
modelId,
|
||||||
const outputStr = new TextDecoder().decode(response.output)
|
messages: convertToBedrockConverseMessages([
|
||||||
const output = JSON.parse(outputStr)
|
{
|
||||||
if (output.content) {
|
role: "user",
|
||||||
return output.content
|
content: prompt,
|
||||||
}
|
},
|
||||||
} catch (parseError) {
|
]),
|
||||||
console.error('Failed to parse Bedrock response:', parseError)
|
inferenceConfig: {
|
||||||
}
|
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||||
}
|
temperature: 0.3,
|
||||||
return ''
|
topP: 0.1,
|
||||||
} catch (error) {
|
},
|
||||||
if (error instanceof Error) {
|
}
|
||||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
|
||||||
}
|
const command = new ConverseCommand(payload)
|
||||||
throw error
|
const response = await this.client.send(command)
|
||||||
}
|
|
||||||
}
|
if (response.output && response.output instanceof Uint8Array) {
|
||||||
|
try {
|
||||||
|
const outputStr = new TextDecoder().decode(response.output)
|
||||||
|
const output = JSON.parse(outputStr)
|
||||||
|
if (output.content) {
|
||||||
|
return output.content
|
||||||
|
}
|
||||||
|
} catch (parseError) {
|
||||||
|
console.error("Failed to parse Bedrock response:", parseError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,24 +3,24 @@ import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
|
|||||||
import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
|
import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
|
||||||
|
|
||||||
export class DeepSeekHandler extends OpenAiHandler {
|
export class DeepSeekHandler extends OpenAiHandler {
|
||||||
constructor(options: ApiHandlerOptions) {
|
constructor(options: ApiHandlerOptions) {
|
||||||
if (!options.deepSeekApiKey) {
|
if (!options.deepSeekApiKey) {
|
||||||
throw new Error("DeepSeek API key is required. Please provide it in the settings.")
|
throw new Error("DeepSeek API key is required. Please provide it in the settings.")
|
||||||
}
|
}
|
||||||
super({
|
super({
|
||||||
...options,
|
...options,
|
||||||
openAiApiKey: options.deepSeekApiKey,
|
openAiApiKey: options.deepSeekApiKey,
|
||||||
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
|
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
|
||||||
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
|
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
|
||||||
includeMaxTokens: true
|
includeMaxTokens: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
override getModel(): { id: string; info: ModelInfo } {
|
override getModel(): { id: string; info: ModelInfo } {
|
||||||
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
|
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
|
||||||
return {
|
return {
|
||||||
id: modelId,
|
id: modelId,
|
||||||
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
|
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,17 +72,17 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
maxTokens = 8_192
|
maxTokens = 8_192
|
||||||
}
|
}
|
||||||
|
|
||||||
const { data: completion, response } = await this.client.chat.completions.create({
|
const { data: completion, response } = await this.client.chat.completions
|
||||||
model: this.getModel().id,
|
.create({
|
||||||
max_tokens: maxTokens,
|
model: this.getModel().id,
|
||||||
temperature: 0,
|
max_tokens: maxTokens,
|
||||||
messages: openAiMessages,
|
temperature: 0,
|
||||||
stream: true,
|
messages: openAiMessages,
|
||||||
}).withResponse();
|
stream: true,
|
||||||
|
})
|
||||||
|
.withResponse()
|
||||||
|
|
||||||
const completionRequestId = response.headers.get(
|
const completionRequestId = response.headers.get("x-completion-request-id")
|
||||||
'x-completion-request-id',
|
|
||||||
);
|
|
||||||
|
|
||||||
for await (const chunk of completion) {
|
for await (const chunk of completion) {
|
||||||
const delta = chunk.choices[0]?.delta
|
const delta = chunk.choices[0]?.delta
|
||||||
@@ -96,13 +96,16 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await axios.get(`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, {
|
const response = await axios.get(
|
||||||
headers: {
|
`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`,
|
||||||
Authorization: `Bearer ${this.options.glamaApiKey}`,
|
{
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${this.options.glamaApiKey}`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
|
|
||||||
const completionRequest = response.data;
|
const completionRequest = response.data
|
||||||
|
|
||||||
if (completionRequest.tokenUsage) {
|
if (completionRequest.tokenUsage) {
|
||||||
yield {
|
yield {
|
||||||
@@ -113,7 +116,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
outputTokens: completionRequest.tokenUsage.completionTokens,
|
outputTokens: completionRequest.tokenUsage.completionTokens,
|
||||||
totalCost: parseFloat(completionRequest.totalCostUsd),
|
totalCost: parseFloat(completionRequest.totalCostUsd),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error fetching Glama completion details", error)
|
console.error("Error fetching Glama completion details", error)
|
||||||
}
|
}
|
||||||
@@ -126,7 +129,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
if (modelId && modelInfo) {
|
if (modelId && modelInfo) {
|
||||||
return { id: modelId, info: modelInfo }
|
return { id: modelId, info: modelInfo }
|
||||||
}
|
}
|
||||||
|
|
||||||
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
|
return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +144,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
if (this.getModel().id.startsWith("anthropic/")) {
|
if (this.getModel().id.startsWith("anthropic/")) {
|
||||||
requestOptions.max_tokens = 8192
|
requestOptions.max_tokens = 8192
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await this.client.chat.completions.create(requestOptions)
|
const response = await this.client.chat.completions.create(requestOptions)
|
||||||
return response.choices[0]?.message.content || ""
|
return response.choices[0]?.message.content || ""
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
model: this.getModel().id,
|
model: this.getModel().id,
|
||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: false
|
stream: false,
|
||||||
})
|
})
|
||||||
return response.choices[0]?.message.content || ""
|
return response.choices[0]?.message.content || ""
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
model: this.getModel().id,
|
model: this.getModel().id,
|
||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: false
|
stream: false,
|
||||||
})
|
})
|
||||||
return response.choices[0]?.message.content || ""
|
return response.choices[0]?.message.content || ""
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@@ -32,7 +32,10 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
|
|||||||
// o1 doesnt support streaming or non-1 temp but does support a developer prompt
|
// o1 doesnt support streaming or non-1 temp but does support a developer prompt
|
||||||
const response = await this.client.chat.completions.create({
|
const response = await this.client.chat.completions.create({
|
||||||
model: modelId,
|
model: modelId,
|
||||||
messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
|
messages: [
|
||||||
|
{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
|
||||||
|
...convertToOpenAiMessages(messages),
|
||||||
|
],
|
||||||
})
|
})
|
||||||
yield {
|
yield {
|
||||||
type: "text",
|
type: "text",
|
||||||
@@ -98,14 +101,14 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
|
|||||||
// o1 doesn't support non-1 temp
|
// o1 doesn't support non-1 temp
|
||||||
requestOptions = {
|
requestOptions = {
|
||||||
model: modelId,
|
model: modelId,
|
||||||
messages: [{ role: "user", content: prompt }]
|
messages: [{ role: "user", content: prompt }],
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
requestOptions = {
|
requestOptions = {
|
||||||
model: modelId,
|
model: modelId,
|
||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
temperature: 0
|
temperature: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
constructor(options: ApiHandlerOptions) {
|
constructor(options: ApiHandlerOptions) {
|
||||||
this.options = options
|
this.options = options
|
||||||
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
|
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
|
||||||
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host;
|
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host
|
||||||
if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) {
|
if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) {
|
||||||
this.client = new AzureOpenAI({
|
this.client = new AzureOpenAI({
|
||||||
baseURL: this.options.openAiBaseUrl,
|
baseURL: this.options.openAiBaseUrl,
|
||||||
@@ -39,7 +39,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
if (this.options.openAiStreamingEnabled ?? true) {
|
if (this.options.openAiStreamingEnabled ?? true) {
|
||||||
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
|
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
|
||||||
role: "system",
|
role: "system",
|
||||||
content: systemPrompt
|
content: systemPrompt,
|
||||||
}
|
}
|
||||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||||
model: modelId,
|
model: modelId,
|
||||||
@@ -74,14 +74,14 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
// o1 for instance doesnt support streaming, non-1 temp, or system prompt
|
// o1 for instance doesnt support streaming, non-1 temp, or system prompt
|
||||||
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
|
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
|
||||||
role: "user",
|
role: "user",
|
||||||
content: systemPrompt
|
content: systemPrompt,
|
||||||
}
|
}
|
||||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||||
model: modelId,
|
model: modelId,
|
||||||
messages: [systemMessage, ...convertToOpenAiMessages(messages)],
|
messages: [systemMessage, ...convertToOpenAiMessages(messages)],
|
||||||
}
|
}
|
||||||
const response = await this.client.chat.completions.create(requestOptions)
|
const response = await this.client.chat.completions.create(requestOptions)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
type: "text",
|
type: "text",
|
||||||
text: response.choices[0]?.message.content || "",
|
text: response.choices[0]?.message.content || "",
|
||||||
@@ -108,7 +108,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await this.client.chat.completions.create(requestOptions)
|
const response = await this.client.chat.completions.create(requestOptions)
|
||||||
return response.choices[0]?.message.content || ""
|
return response.choices[0]?.message.content || ""
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ import delay from "delay"
|
|||||||
|
|
||||||
// Add custom interface for OpenRouter params
|
// Add custom interface for OpenRouter params
|
||||||
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
|
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
|
||||||
transforms?: string[];
|
transforms?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add custom interface for OpenRouter usage chunk
|
// Add custom interface for OpenRouter usage chunk
|
||||||
interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
|
interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
|
||||||
fullResponseText: string;
|
fullResponseText: string
|
||||||
}
|
}
|
||||||
|
|
||||||
import { SingleCompletionHandler } from ".."
|
import { SingleCompletionHandler } from ".."
|
||||||
@@ -35,7 +35,10 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator<ApiStreamChunk> {
|
async *createMessage(
|
||||||
|
systemPrompt: string,
|
||||||
|
messages: Anthropic.Messages.MessageParam[],
|
||||||
|
): AsyncGenerator<ApiStreamChunk> {
|
||||||
// Convert Anthropic messages to OpenAI format
|
// Convert Anthropic messages to OpenAI format
|
||||||
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
|
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
|
||||||
{ role: "system", content: systemPrompt },
|
{ role: "system", content: systemPrompt },
|
||||||
@@ -108,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
// https://openrouter.ai/docs/transforms
|
// https://openrouter.ai/docs/transforms
|
||||||
let fullResponseText = "";
|
let fullResponseText = ""
|
||||||
const stream = await this.client.chat.completions.create({
|
const stream = await this.client.chat.completions.create({
|
||||||
model: this.getModel().id,
|
model: this.getModel().id,
|
||||||
max_tokens: maxTokens,
|
max_tokens: maxTokens,
|
||||||
@@ -116,8 +119,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
messages: openAiMessages,
|
messages: openAiMessages,
|
||||||
stream: true,
|
stream: true,
|
||||||
// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
|
// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
|
||||||
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] })
|
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
|
||||||
} as OpenRouterChatCompletionParams);
|
} as OpenRouterChatCompletionParams)
|
||||||
|
|
||||||
let genId: string | undefined
|
let genId: string | undefined
|
||||||
|
|
||||||
@@ -135,11 +138,11 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
|
|
||||||
const delta = chunk.choices[0]?.delta
|
const delta = chunk.choices[0]?.delta
|
||||||
if (delta?.content) {
|
if (delta?.content) {
|
||||||
fullResponseText += delta.content;
|
fullResponseText += delta.content
|
||||||
yield {
|
yield {
|
||||||
type: "text",
|
type: "text",
|
||||||
text: delta.content,
|
text: delta.content,
|
||||||
} as ApiStreamChunk;
|
} as ApiStreamChunk
|
||||||
}
|
}
|
||||||
// if (chunk.usage) {
|
// if (chunk.usage) {
|
||||||
// yield {
|
// yield {
|
||||||
@@ -170,13 +173,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
inputTokens: generation?.native_tokens_prompt || 0,
|
inputTokens: generation?.native_tokens_prompt || 0,
|
||||||
outputTokens: generation?.native_tokens_completion || 0,
|
outputTokens: generation?.native_tokens_completion || 0,
|
||||||
totalCost: generation?.total_cost || 0,
|
totalCost: generation?.total_cost || 0,
|
||||||
fullResponseText
|
fullResponseText,
|
||||||
} as OpenRouterApiStreamUsageChunk;
|
} as OpenRouterApiStreamUsageChunk
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// ignore if fails
|
// ignore if fails
|
||||||
console.error("Error fetching OpenRouter generation details:", error)
|
console.error("Error fetching OpenRouter generation details:", error)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
getModel(): { id: string; info: ModelInfo } {
|
getModel(): { id: string; info: ModelInfo } {
|
||||||
const modelId = this.options.openRouterModelId
|
const modelId = this.options.openRouterModelId
|
||||||
@@ -193,7 +195,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
model: this.getModel().id,
|
model: this.getModel().id,
|
||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: false
|
stream: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
if ("error" in response) {
|
if ("error" in response) {
|
||||||
|
|||||||
@@ -91,14 +91,14 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
messages: [{ role: "user", content: prompt }],
|
messages: [{ role: "user", content: prompt }],
|
||||||
stream: false
|
stream: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
const content = response.content[0]
|
const content = response.content[0]
|
||||||
if (content.type === 'text') {
|
if (content.type === "text") {
|
||||||
return content.text
|
return content.text
|
||||||
}
|
}
|
||||||
return ''
|
return ""
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
throw new Error(`Vertex completion error: ${error.message}`)
|
throw new Error(`Vertex completion error: ${error.message}`)
|
||||||
|
|||||||
@@ -1,31 +1,31 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk";
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import * as vscode from 'vscode';
|
import * as vscode from "vscode"
|
||||||
import { ApiHandler, SingleCompletionHandler } from "../";
|
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||||
import { calculateApiCost } from "../../utils/cost";
|
import { calculateApiCost } from "../../utils/cost"
|
||||||
import { ApiStream } from "../transform/stream";
|
import { ApiStream } from "../transform/stream"
|
||||||
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format";
|
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
|
||||||
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils";
|
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
|
||||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api";
|
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handles interaction with VS Code's Language Model API for chat-based operations.
|
* Handles interaction with VS Code's Language Model API for chat-based operations.
|
||||||
* This handler implements the ApiHandler interface to provide VS Code LM specific functionality.
|
* This handler implements the ApiHandler interface to provide VS Code LM specific functionality.
|
||||||
*
|
*
|
||||||
* @implements {ApiHandler}
|
* @implements {ApiHandler}
|
||||||
*
|
*
|
||||||
* @remarks
|
* @remarks
|
||||||
* The handler manages a VS Code language model chat client and provides methods to:
|
* The handler manages a VS Code language model chat client and provides methods to:
|
||||||
* - Create and manage chat client instances
|
* - Create and manage chat client instances
|
||||||
* - Stream messages using VS Code's Language Model API
|
* - Stream messages using VS Code's Language Model API
|
||||||
* - Retrieve model information
|
* - Retrieve model information
|
||||||
*
|
*
|
||||||
* @example
|
* @example
|
||||||
* ```typescript
|
* ```typescript
|
||||||
* const options = {
|
* const options = {
|
||||||
* vsCodeLmModelSelector: { vendor: "copilot", family: "gpt-4" }
|
* vsCodeLmModelSelector: { vendor: "copilot", family: "gpt-4" }
|
||||||
* };
|
* };
|
||||||
* const handler = new VsCodeLmHandler(options);
|
* const handler = new VsCodeLmHandler(options);
|
||||||
*
|
*
|
||||||
* // Stream a conversation
|
* // Stream a conversation
|
||||||
* const systemPrompt = "You are a helpful assistant";
|
* const systemPrompt = "You are a helpful assistant";
|
||||||
* const messages = [{ role: "user", content: "Hello!" }];
|
* const messages = [{ role: "user", content: "Hello!" }];
|
||||||
@@ -35,39 +35,36 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
|
|||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
|
private options: ApiHandlerOptions
|
||||||
private options: ApiHandlerOptions;
|
private client: vscode.LanguageModelChat | null
|
||||||
private client: vscode.LanguageModelChat | null;
|
private disposable: vscode.Disposable | null
|
||||||
private disposable: vscode.Disposable | null;
|
private currentRequestCancellation: vscode.CancellationTokenSource | null
|
||||||
private currentRequestCancellation: vscode.CancellationTokenSource | null;
|
|
||||||
|
|
||||||
constructor(options: ApiHandlerOptions) {
|
constructor(options: ApiHandlerOptions) {
|
||||||
this.options = options;
|
this.options = options
|
||||||
this.client = null;
|
this.client = null
|
||||||
this.disposable = null;
|
this.disposable = null
|
||||||
this.currentRequestCancellation = null;
|
this.currentRequestCancellation = null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Listen for model changes and reset client
|
// Listen for model changes and reset client
|
||||||
this.disposable = vscode.workspace.onDidChangeConfiguration(event => {
|
this.disposable = vscode.workspace.onDidChangeConfiguration((event) => {
|
||||||
if (event.affectsConfiguration('lm')) {
|
if (event.affectsConfiguration("lm")) {
|
||||||
try {
|
try {
|
||||||
this.client = null;
|
this.client = null
|
||||||
this.ensureCleanState();
|
this.ensureCleanState()
|
||||||
}
|
} catch (error) {
|
||||||
catch (error) {
|
console.error("Error during configuration change cleanup:", error)
|
||||||
console.error('Error during configuration change cleanup:', error);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
}
|
} catch (error) {
|
||||||
catch (error) {
|
|
||||||
// Ensure cleanup if constructor fails
|
// Ensure cleanup if constructor fails
|
||||||
this.dispose();
|
this.dispose()
|
||||||
|
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : 'Unknown error'}`
|
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||||
);
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,46 +74,46 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
* @param selector - Selector criteria to filter language model chat instances
|
* @param selector - Selector criteria to filter language model chat instances
|
||||||
* @returns Promise resolving to the first matching language model chat instance
|
* @returns Promise resolving to the first matching language model chat instance
|
||||||
* @throws Error when no matching models are found with the given selector
|
* @throws Error when no matching models are found with the given selector
|
||||||
*
|
*
|
||||||
* @example
|
* @example
|
||||||
* const selector = { vendor: "copilot", family: "gpt-4o" };
|
* const selector = { vendor: "copilot", family: "gpt-4o" };
|
||||||
* const chatClient = await createClient(selector);
|
* const chatClient = await createClient(selector);
|
||||||
*/
|
*/
|
||||||
async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> {
|
async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> {
|
||||||
try {
|
try {
|
||||||
const models = await vscode.lm.selectChatModels(selector);
|
const models = await vscode.lm.selectChatModels(selector)
|
||||||
|
|
||||||
// Use first available model or create a minimal model object
|
// Use first available model or create a minimal model object
|
||||||
if (models && Array.isArray(models) && models.length > 0) {
|
if (models && Array.isArray(models) && models.length > 0) {
|
||||||
return models[0];
|
return models[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a minimal model if no models are available
|
// Create a minimal model if no models are available
|
||||||
return {
|
return {
|
||||||
id: 'default-lm',
|
id: "default-lm",
|
||||||
name: 'Default Language Model',
|
name: "Default Language Model",
|
||||||
vendor: 'vscode',
|
vendor: "vscode",
|
||||||
family: 'lm',
|
family: "lm",
|
||||||
version: '1.0',
|
version: "1.0",
|
||||||
maxInputTokens: 8192,
|
maxInputTokens: 8192,
|
||||||
sendRequest: async (messages, options, token) => {
|
sendRequest: async (messages, options, token) => {
|
||||||
// Provide a minimal implementation
|
// Provide a minimal implementation
|
||||||
return {
|
return {
|
||||||
stream: (async function* () {
|
stream: (async function* () {
|
||||||
yield new vscode.LanguageModelTextPart(
|
yield new vscode.LanguageModelTextPart(
|
||||||
"Language model functionality is limited. Please check VS Code configuration."
|
"Language model functionality is limited. Please check VS Code configuration.",
|
||||||
);
|
)
|
||||||
})(),
|
})(),
|
||||||
text: (async function* () {
|
text: (async function* () {
|
||||||
yield "Language model functionality is limited. Please check VS Code configuration.";
|
yield "Language model functionality is limited. Please check VS Code configuration."
|
||||||
})()
|
})(),
|
||||||
};
|
}
|
||||||
},
|
},
|
||||||
countTokens: async () => 0
|
countTokens: async () => 0,
|
||||||
};
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
const errorMessage = error instanceof Error ? error.message : "Unknown error"
|
||||||
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`);
|
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,242 +122,234 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
*
|
*
|
||||||
* @param systemPrompt - The system prompt to initialize the conversation context
|
* @param systemPrompt - The system prompt to initialize the conversation context
|
||||||
* @param messages - An array of message parameters following the Anthropic message format
|
* @param messages - An array of message parameters following the Anthropic message format
|
||||||
*
|
*
|
||||||
* @yields {ApiStream} An async generator that yields either text chunks or tool calls from the model response
|
* @yields {ApiStream} An async generator that yields either text chunks or tool calls from the model response
|
||||||
*
|
*
|
||||||
* @throws {Error} When vsCodeLmModelSelector option is not provided
|
* @throws {Error} When vsCodeLmModelSelector option is not provided
|
||||||
* @throws {Error} When the response stream encounters an error
|
* @throws {Error} When the response stream encounters an error
|
||||||
*
|
*
|
||||||
* @remarks
|
* @remarks
|
||||||
* This method handles the initialization of the VS Code LM client if not already created,
|
* This method handles the initialization of the VS Code LM client if not already created,
|
||||||
* converts the messages to VS Code LM format, and streams the response chunks.
|
* converts the messages to VS Code LM format, and streams the response chunks.
|
||||||
* Tool calls handling is currently a work in progress.
|
* Tool calls handling is currently a work in progress.
|
||||||
*/
|
*/
|
||||||
dispose(): void {
|
dispose(): void {
|
||||||
|
|
||||||
if (this.disposable) {
|
if (this.disposable) {
|
||||||
|
this.disposable.dispose()
|
||||||
this.disposable.dispose();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.currentRequestCancellation) {
|
if (this.currentRequestCancellation) {
|
||||||
|
this.currentRequestCancellation.cancel()
|
||||||
this.currentRequestCancellation.cancel();
|
this.currentRequestCancellation.dispose()
|
||||||
this.currentRequestCancellation.dispose();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
|
private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
|
||||||
// Check for required dependencies
|
// Check for required dependencies
|
||||||
if (!this.client) {
|
if (!this.client) {
|
||||||
console.warn('Cline <Language Model API>: No client available for token counting');
|
console.warn("Cline <Language Model API>: No client available for token counting")
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!this.currentRequestCancellation) {
|
if (!this.currentRequestCancellation) {
|
||||||
console.warn('Cline <Language Model API>: No cancellation token available for token counting');
|
console.warn("Cline <Language Model API>: No cancellation token available for token counting")
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate input
|
// Validate input
|
||||||
if (!text) {
|
if (!text) {
|
||||||
console.debug('Cline <Language Model API>: Empty text provided for token counting');
|
console.debug("Cline <Language Model API>: Empty text provided for token counting")
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Handle different input types
|
// Handle different input types
|
||||||
let tokenCount: number;
|
let tokenCount: number
|
||||||
|
|
||||||
if (typeof text === 'string') {
|
if (typeof text === "string") {
|
||||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
|
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
|
||||||
} else if (text instanceof vscode.LanguageModelChatMessage) {
|
} else if (text instanceof vscode.LanguageModelChatMessage) {
|
||||||
// For chat messages, ensure we have content
|
// For chat messages, ensure we have content
|
||||||
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
|
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
|
||||||
console.debug('Cline <Language Model API>: Empty chat message content');
|
console.debug("Cline <Language Model API>: Empty chat message content")
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
|
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
|
||||||
} else {
|
} else {
|
||||||
console.warn('Cline <Language Model API>: Invalid input type for token counting');
|
console.warn("Cline <Language Model API>: Invalid input type for token counting")
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the result
|
// Validate the result
|
||||||
if (typeof tokenCount !== 'number') {
|
if (typeof tokenCount !== "number") {
|
||||||
console.warn('Cline <Language Model API>: Non-numeric token count received:', tokenCount);
|
console.warn("Cline <Language Model API>: Non-numeric token count received:", tokenCount)
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tokenCount < 0) {
|
if (tokenCount < 0) {
|
||||||
console.warn('Cline <Language Model API>: Negative token count received:', tokenCount);
|
console.warn("Cline <Language Model API>: Negative token count received:", tokenCount)
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenCount;
|
return tokenCount
|
||||||
}
|
} catch (error) {
|
||||||
catch (error) {
|
|
||||||
// Handle specific error types
|
// Handle specific error types
|
||||||
if (error instanceof vscode.CancellationError) {
|
if (error instanceof vscode.CancellationError) {
|
||||||
console.debug('Cline <Language Model API>: Token counting cancelled by user');
|
console.debug("Cline <Language Model API>: Token counting cancelled by user")
|
||||||
return 0;
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
const errorMessage = error instanceof Error ? error.message : "Unknown error"
|
||||||
console.warn('Cline <Language Model API>: Token counting failed:', errorMessage);
|
console.warn("Cline <Language Model API>: Token counting failed:", errorMessage)
|
||||||
|
|
||||||
// Log additional error details if available
|
// Log additional error details if available
|
||||||
if (error instanceof Error && error.stack) {
|
if (error instanceof Error && error.stack) {
|
||||||
console.debug('Token counting error stack:', error.stack);
|
console.debug("Token counting error stack:", error.stack)
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0; // Fallback to prevent stream interruption
|
return 0 // Fallback to prevent stream interruption
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async calculateTotalInputTokens(systemPrompt: string, vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
|
private async calculateTotalInputTokens(
|
||||||
|
systemPrompt: string,
|
||||||
|
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
|
||||||
|
): Promise<number> {
|
||||||
|
const systemTokens: number = await this.countTokens(systemPrompt)
|
||||||
|
|
||||||
const systemTokens: number = await this.countTokens(systemPrompt);
|
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.countTokens(msg)))
|
||||||
|
|
||||||
const messageTokens: number[] = await Promise.all(
|
return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
|
||||||
vsCodeLmMessages.map(msg => this.countTokens(msg))
|
|
||||||
);
|
|
||||||
|
|
||||||
return systemTokens + messageTokens.reduce(
|
|
||||||
(sum: number, tokens: number): number => sum + tokens, 0
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private ensureCleanState(): void {
|
private ensureCleanState(): void {
|
||||||
|
|
||||||
if (this.currentRequestCancellation) {
|
if (this.currentRequestCancellation) {
|
||||||
|
this.currentRequestCancellation.cancel()
|
||||||
this.currentRequestCancellation.cancel();
|
this.currentRequestCancellation.dispose()
|
||||||
this.currentRequestCancellation.dispose();
|
this.currentRequestCancellation = null
|
||||||
this.currentRequestCancellation = null;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getClient(): Promise<vscode.LanguageModelChat> {
|
private async getClient(): Promise<vscode.LanguageModelChat> {
|
||||||
if (!this.client) {
|
if (!this.client) {
|
||||||
console.debug('Cline <Language Model API>: Getting client with options:', {
|
console.debug("Cline <Language Model API>: Getting client with options:", {
|
||||||
vsCodeLmModelSelector: this.options.vsCodeLmModelSelector,
|
vsCodeLmModelSelector: this.options.vsCodeLmModelSelector,
|
||||||
hasOptions: !!this.options,
|
hasOptions: !!this.options,
|
||||||
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : []
|
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [],
|
||||||
});
|
})
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Use default empty selector if none provided to get all available models
|
// Use default empty selector if none provided to get all available models
|
||||||
const selector = this.options?.vsCodeLmModelSelector || {};
|
const selector = this.options?.vsCodeLmModelSelector || {}
|
||||||
console.debug('Cline <Language Model API>: Creating client with selector:', selector);
|
console.debug("Cline <Language Model API>: Creating client with selector:", selector)
|
||||||
this.client = await this.createClient(selector);
|
this.client = await this.createClient(selector)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const message = error instanceof Error ? error.message : 'Unknown error';
|
const message = error instanceof Error ? error.message : "Unknown error"
|
||||||
console.error('Cline <Language Model API>: Client creation failed:', message);
|
console.error("Cline <Language Model API>: Client creation failed:", message)
|
||||||
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`);
|
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.client;
|
return this.client
|
||||||
}
|
}
|
||||||
|
|
||||||
private cleanTerminalOutput(text: string): string {
|
private cleanTerminalOutput(text: string): string {
|
||||||
if (!text) {
|
if (!text) {
|
||||||
return '';
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return text
|
return (
|
||||||
// Нормализуем переносы строк
|
text
|
||||||
.replace(/\r\n/g, '\n')
|
// Нормализуем переносы строк
|
||||||
.replace(/\r/g, '\n')
|
.replace(/\r\n/g, "\n")
|
||||||
|
.replace(/\r/g, "\n")
|
||||||
|
|
||||||
// Удаляем ANSI escape sequences
|
// Удаляем ANSI escape sequences
|
||||||
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, '') // Полный набор ANSI sequences
|
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // Полный набор ANSI sequences
|
||||||
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, '') // CSI sequences
|
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, "") // CSI sequences
|
||||||
|
|
||||||
// Удаляем последовательности установки заголовка терминала и прочие OSC sequences
|
// Удаляем последовательности установки заголовка терминала и прочие OSC sequences
|
||||||
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, '')
|
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, "")
|
||||||
|
|
||||||
// Удаляем управляющие символы
|
// Удаляем управляющие символы
|
||||||
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, '')
|
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, "")
|
||||||
|
|
||||||
// Удаляем escape-последовательности VS Code
|
// Удаляем escape-последовательности VS Code
|
||||||
.replace(/\x1B[PD].*?\x1B\\/g, '') // DCS sequences
|
.replace(/\x1B[PD].*?\x1B\\/g, "") // DCS sequences
|
||||||
.replace(/\x1B_.*?\x1B\\/g, '') // APC sequences
|
.replace(/\x1B_.*?\x1B\\/g, "") // APC sequences
|
||||||
.replace(/\x1B\^.*?\x1B\\/g, '') // PM sequences
|
.replace(/\x1B\^.*?\x1B\\/g, "") // PM sequences
|
||||||
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, '') // Cursor movement and clear screen
|
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, "") // Cursor movement and clear screen
|
||||||
|
|
||||||
// Удаляем пути Windows и служебную информацию
|
// Удаляем пути Windows и служебную информацию
|
||||||
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/mg, '')
|
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/gm, "")
|
||||||
.replace(/^;?Cwd=.*$/mg, '')
|
.replace(/^;?Cwd=.*$/gm, "")
|
||||||
|
|
||||||
// Очищаем экранированные последовательности
|
// Очищаем экранированные последовательности
|
||||||
.replace(/\\x[0-9a-fA-F]{2}/g, '')
|
.replace(/\\x[0-9a-fA-F]{2}/g, "")
|
||||||
.replace(/\\u[0-9a-fA-F]{4}/g, '')
|
.replace(/\\u[0-9a-fA-F]{4}/g, "")
|
||||||
|
|
||||||
// Финальная очистка
|
// Финальная очистка
|
||||||
.replace(/\n{3,}/g, '\n\n') // Убираем множественные пустые строки
|
.replace(/\n{3,}/g, "\n\n") // Убираем множественные пустые строки
|
||||||
.trim();
|
.trim()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
private cleanMessageContent(content: any): any {
|
private cleanMessageContent(content: any): any {
|
||||||
if (!content) {
|
if (!content) {
|
||||||
return content;
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
if (typeof content === 'string') {
|
if (typeof content === "string") {
|
||||||
return this.cleanTerminalOutput(content);
|
return this.cleanTerminalOutput(content)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Array.isArray(content)) {
|
if (Array.isArray(content)) {
|
||||||
return content.map(item => this.cleanMessageContent(item));
|
return content.map((item) => this.cleanMessageContent(item))
|
||||||
}
|
}
|
||||||
|
|
||||||
if (typeof content === 'object') {
|
if (typeof content === "object") {
|
||||||
const cleaned: any = {};
|
const cleaned: any = {}
|
||||||
for (const [key, value] of Object.entries(content)) {
|
for (const [key, value] of Object.entries(content)) {
|
||||||
cleaned[key] = this.cleanMessageContent(value);
|
cleaned[key] = this.cleanMessageContent(value)
|
||||||
}
|
}
|
||||||
return cleaned;
|
return cleaned
|
||||||
}
|
}
|
||||||
|
|
||||||
return content;
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||||
|
|
||||||
// Ensure clean state before starting a new request
|
// Ensure clean state before starting a new request
|
||||||
this.ensureCleanState();
|
this.ensureCleanState()
|
||||||
const client: vscode.LanguageModelChat = await this.getClient();
|
const client: vscode.LanguageModelChat = await this.getClient()
|
||||||
|
|
||||||
// Clean system prompt and messages
|
// Clean system prompt and messages
|
||||||
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt);
|
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt)
|
||||||
const cleanedMessages = messages.map(msg => ({
|
const cleanedMessages = messages.map((msg) => ({
|
||||||
...msg,
|
...msg,
|
||||||
content: this.cleanMessageContent(msg.content)
|
content: this.cleanMessageContent(msg.content),
|
||||||
}));
|
}))
|
||||||
|
|
||||||
// Convert Anthropic messages to VS Code LM messages
|
// Convert Anthropic messages to VS Code LM messages
|
||||||
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [
|
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [
|
||||||
vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt),
|
vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt),
|
||||||
...convertToVsCodeLmMessages(cleanedMessages),
|
...convertToVsCodeLmMessages(cleanedMessages),
|
||||||
];
|
]
|
||||||
|
|
||||||
// Initialize cancellation token for the request
|
// Initialize cancellation token for the request
|
||||||
this.currentRequestCancellation = new vscode.CancellationTokenSource();
|
this.currentRequestCancellation = new vscode.CancellationTokenSource()
|
||||||
|
|
||||||
// Calculate input tokens before starting the stream
|
// Calculate input tokens before starting the stream
|
||||||
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages);
|
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
|
||||||
|
|
||||||
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
|
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
|
||||||
let accumulatedText: string = '';
|
let accumulatedText: string = ""
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
||||||
// Create the response stream with minimal required options
|
// Create the response stream with minimal required options
|
||||||
const requestOptions: vscode.LanguageModelChatRequestOptions = {
|
const requestOptions: vscode.LanguageModelChatRequestOptions = {
|
||||||
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`
|
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`,
|
||||||
};
|
}
|
||||||
|
|
||||||
// Note: Tool support is currently provided by the VSCode Language Model API directly
|
// Note: Tool support is currently provided by the VSCode Language Model API directly
|
||||||
// Extensions can register tools using vscode.lm.registerTool()
|
// Extensions can register tools using vscode.lm.registerTool()
|
||||||
@@ -368,40 +357,40 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
const response: vscode.LanguageModelChatResponse = await client.sendRequest(
|
const response: vscode.LanguageModelChatResponse = await client.sendRequest(
|
||||||
vsCodeLmMessages,
|
vsCodeLmMessages,
|
||||||
requestOptions,
|
requestOptions,
|
||||||
this.currentRequestCancellation.token
|
this.currentRequestCancellation.token,
|
||||||
);
|
)
|
||||||
|
|
||||||
// Consume the stream and handle both text and tool call chunks
|
// Consume the stream and handle both text and tool call chunks
|
||||||
for await (const chunk of response.stream) {
|
for await (const chunk of response.stream) {
|
||||||
if (chunk instanceof vscode.LanguageModelTextPart) {
|
if (chunk instanceof vscode.LanguageModelTextPart) {
|
||||||
// Validate text part value
|
// Validate text part value
|
||||||
if (typeof chunk.value !== 'string') {
|
if (typeof chunk.value !== "string") {
|
||||||
console.warn('Cline <Language Model API>: Invalid text part value received:', chunk.value);
|
console.warn("Cline <Language Model API>: Invalid text part value received:", chunk.value)
|
||||||
continue;
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
accumulatedText += chunk.value;
|
accumulatedText += chunk.value
|
||||||
yield {
|
yield {
|
||||||
type: "text",
|
type: "text",
|
||||||
text: chunk.value,
|
text: chunk.value,
|
||||||
};
|
}
|
||||||
} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
|
} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
|
||||||
try {
|
try {
|
||||||
// Validate tool call parameters
|
// Validate tool call parameters
|
||||||
if (!chunk.name || typeof chunk.name !== 'string') {
|
if (!chunk.name || typeof chunk.name !== "string") {
|
||||||
console.warn('Cline <Language Model API>: Invalid tool name received:', chunk.name);
|
console.warn("Cline <Language Model API>: Invalid tool name received:", chunk.name)
|
||||||
continue;
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!chunk.callId || typeof chunk.callId !== 'string') {
|
if (!chunk.callId || typeof chunk.callId !== "string") {
|
||||||
console.warn('Cline <Language Model API>: Invalid tool callId received:', chunk.callId);
|
console.warn("Cline <Language Model API>: Invalid tool callId received:", chunk.callId)
|
||||||
continue;
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure input is a valid object
|
// Ensure input is a valid object
|
||||||
if (!chunk.input || typeof chunk.input !== 'object') {
|
if (!chunk.input || typeof chunk.input !== "object") {
|
||||||
console.warn('Cline <Language Model API>: Invalid tool input received:', chunk.input);
|
console.warn("Cline <Language Model API>: Invalid tool input received:", chunk.input)
|
||||||
continue;
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert tool calls to text format with proper error handling
|
// Convert tool calls to text format with proper error handling
|
||||||
@@ -409,82 +398,75 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
type: "tool_call",
|
type: "tool_call",
|
||||||
name: chunk.name,
|
name: chunk.name,
|
||||||
arguments: chunk.input,
|
arguments: chunk.input,
|
||||||
callId: chunk.callId
|
callId: chunk.callId,
|
||||||
};
|
}
|
||||||
|
|
||||||
const toolCallText = JSON.stringify(toolCall);
|
const toolCallText = JSON.stringify(toolCall)
|
||||||
accumulatedText += toolCallText;
|
accumulatedText += toolCallText
|
||||||
|
|
||||||
// Log tool call for debugging
|
// Log tool call for debugging
|
||||||
console.debug('Cline <Language Model API>: Processing tool call:', {
|
console.debug("Cline <Language Model API>: Processing tool call:", {
|
||||||
name: chunk.name,
|
name: chunk.name,
|
||||||
callId: chunk.callId,
|
callId: chunk.callId,
|
||||||
inputSize: JSON.stringify(chunk.input).length
|
inputSize: JSON.stringify(chunk.input).length,
|
||||||
});
|
})
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
type: "text",
|
type: "text",
|
||||||
text: toolCallText,
|
text: toolCallText,
|
||||||
};
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Cline <Language Model API>: Failed to process tool call:', error);
|
console.error("Cline <Language Model API>: Failed to process tool call:", error)
|
||||||
// Continue processing other chunks even if one fails
|
// Continue processing other chunks even if one fails
|
||||||
continue;
|
continue
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
console.warn('Cline <Language Model API>: Unknown chunk type received:', chunk);
|
console.warn("Cline <Language Model API>: Unknown chunk type received:", chunk)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count tokens in the accumulated text after stream completion
|
// Count tokens in the accumulated text after stream completion
|
||||||
const totalOutputTokens: number = await this.countTokens(accumulatedText);
|
const totalOutputTokens: number = await this.countTokens(accumulatedText)
|
||||||
|
|
||||||
// Report final usage after stream completion
|
// Report final usage after stream completion
|
||||||
yield {
|
yield {
|
||||||
type: "usage",
|
type: "usage",
|
||||||
inputTokens: totalInputTokens,
|
inputTokens: totalInputTokens,
|
||||||
outputTokens: totalOutputTokens,
|
outputTokens: totalOutputTokens,
|
||||||
totalCost: calculateApiCost(
|
totalCost: calculateApiCost(this.getModel().info, totalInputTokens, totalOutputTokens),
|
||||||
this.getModel().info,
|
}
|
||||||
totalInputTokens,
|
} catch (error: unknown) {
|
||||||
totalOutputTokens
|
this.ensureCleanState()
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
catch (error: unknown) {
|
|
||||||
|
|
||||||
this.ensureCleanState();
|
|
||||||
|
|
||||||
if (error instanceof vscode.CancellationError) {
|
if (error instanceof vscode.CancellationError) {
|
||||||
|
throw new Error("Cline <Language Model API>: Request cancelled by user")
|
||||||
throw new Error("Cline <Language Model API>: Request cancelled by user");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
console.error('Cline <Language Model API>: Stream error details:', {
|
console.error("Cline <Language Model API>: Stream error details:", {
|
||||||
message: error.message,
|
message: error.message,
|
||||||
stack: error.stack,
|
stack: error.stack,
|
||||||
name: error.name
|
name: error.name,
|
||||||
});
|
})
|
||||||
|
|
||||||
// Return original error if it's already an Error instance
|
// Return original error if it's already an Error instance
|
||||||
throw error;
|
throw error
|
||||||
} else if (typeof error === 'object' && error !== null) {
|
} else if (typeof error === "object" && error !== null) {
|
||||||
// Handle error-like objects
|
// Handle error-like objects
|
||||||
const errorDetails = JSON.stringify(error, null, 2);
|
const errorDetails = JSON.stringify(error, null, 2)
|
||||||
console.error('Cline <Language Model API>: Stream error object:', errorDetails);
|
console.error("Cline <Language Model API>: Stream error object:", errorDetails)
|
||||||
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`);
|
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`)
|
||||||
} else {
|
} else {
|
||||||
// Fallback for unknown error types
|
// Fallback for unknown error types
|
||||||
const errorMessage = String(error);
|
const errorMessage = String(error)
|
||||||
console.error('Cline <Language Model API>: Unknown stream error:', errorMessage);
|
console.error("Cline <Language Model API>: Unknown stream error:", errorMessage)
|
||||||
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`);
|
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return model information based on the current client state
|
// Return model information based on the current client state
|
||||||
getModel(): { id: string; info: ModelInfo; } {
|
getModel(): { id: string; info: ModelInfo } {
|
||||||
if (this.client) {
|
if (this.client) {
|
||||||
// Validate client properties
|
// Validate client properties
|
||||||
const requiredProps = {
|
const requiredProps = {
|
||||||
@@ -492,68 +474,69 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
vendor: this.client.vendor,
|
vendor: this.client.vendor,
|
||||||
family: this.client.family,
|
family: this.client.family,
|
||||||
version: this.client.version,
|
version: this.client.version,
|
||||||
maxInputTokens: this.client.maxInputTokens
|
maxInputTokens: this.client.maxInputTokens,
|
||||||
};
|
}
|
||||||
|
|
||||||
// Log any missing properties for debugging
|
// Log any missing properties for debugging
|
||||||
for (const [prop, value] of Object.entries(requiredProps)) {
|
for (const [prop, value] of Object.entries(requiredProps)) {
|
||||||
if (!value && value !== 0) {
|
if (!value && value !== 0) {
|
||||||
console.warn(`Cline <Language Model API>: Client missing ${prop} property`);
|
console.warn(`Cline <Language Model API>: Client missing ${prop} property`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct model ID using available information
|
// Construct model ID using available information
|
||||||
const modelParts = [
|
const modelParts = [this.client.vendor, this.client.family, this.client.version].filter(Boolean)
|
||||||
this.client.vendor,
|
|
||||||
this.client.family,
|
|
||||||
this.client.version
|
|
||||||
].filter(Boolean);
|
|
||||||
|
|
||||||
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR);
|
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR)
|
||||||
|
|
||||||
// Build model info with conservative defaults for missing values
|
// Build model info with conservative defaults for missing values
|
||||||
const modelInfo: ModelInfo = {
|
const modelInfo: ModelInfo = {
|
||||||
maxTokens: -1, // Unlimited tokens by default
|
maxTokens: -1, // Unlimited tokens by default
|
||||||
contextWindow: typeof this.client.maxInputTokens === 'number'
|
contextWindow:
|
||||||
? Math.max(0, this.client.maxInputTokens)
|
typeof this.client.maxInputTokens === "number"
|
||||||
: openAiModelInfoSaneDefaults.contextWindow,
|
? Math.max(0, this.client.maxInputTokens)
|
||||||
|
: openAiModelInfoSaneDefaults.contextWindow,
|
||||||
supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
|
supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
|
||||||
supportsPromptCache: true,
|
supportsPromptCache: true,
|
||||||
inputPrice: 0,
|
inputPrice: 0,
|
||||||
outputPrice: 0,
|
outputPrice: 0,
|
||||||
description: `VSCode Language Model: ${modelId}`
|
description: `VSCode Language Model: ${modelId}`,
|
||||||
};
|
}
|
||||||
|
|
||||||
return { id: modelId, info: modelInfo };
|
return { id: modelId, info: modelInfo }
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback when no client is available
|
// Fallback when no client is available
|
||||||
const fallbackId = this.options.vsCodeLmModelSelector
|
const fallbackId = this.options.vsCodeLmModelSelector
|
||||||
? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector)
|
? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector)
|
||||||
: "vscode-lm";
|
: "vscode-lm"
|
||||||
|
|
||||||
console.debug('Cline <Language Model API>: No client available, using fallback model info');
|
console.debug("Cline <Language Model API>: No client available, using fallback model info")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
id: fallbackId,
|
id: fallbackId,
|
||||||
info: {
|
info: {
|
||||||
...openAiModelInfoSaneDefaults,
|
...openAiModelInfoSaneDefaults,
|
||||||
description: `VSCode Language Model (Fallback): ${fallbackId}`
|
description: `VSCode Language Model (Fallback): ${fallbackId}`,
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async completePrompt(prompt: string): Promise<string> {
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
try {
|
try {
|
||||||
const client = await this.getClient();
|
const client = await this.getClient()
|
||||||
const response = await client.sendRequest([vscode.LanguageModelChatMessage.User(prompt)], {}, new vscode.CancellationTokenSource().token);
|
const response = await client.sendRequest(
|
||||||
let result = "";
|
[vscode.LanguageModelChatMessage.User(prompt)],
|
||||||
|
{},
|
||||||
|
new vscode.CancellationTokenSource().token,
|
||||||
|
)
|
||||||
|
let result = ""
|
||||||
for await (const chunk of response.stream) {
|
for await (const chunk of response.stream) {
|
||||||
if (chunk instanceof vscode.LanguageModelTextPart) {
|
if (chunk instanceof vscode.LanguageModelTextPart) {
|
||||||
result += chunk.value;
|
result += chunk.value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
throw new Error(`VSCode LM completion error: ${error.message}`)
|
throw new Error(`VSCode LM completion error: ${error.message}`)
|
||||||
|
|||||||
@@ -1,252 +1,250 @@
|
|||||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format'
|
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../bedrock-converse-format"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk'
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime'
|
import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime"
|
||||||
import { StreamEvent } from '../../providers/bedrock'
|
import { StreamEvent } from "../../providers/bedrock"
|
||||||
|
|
||||||
describe('bedrock-converse-format', () => {
|
describe("bedrock-converse-format", () => {
|
||||||
describe('convertToBedrockConverseMessages', () => {
|
describe("convertToBedrockConverseMessages", () => {
|
||||||
test('converts simple text messages correctly', () => {
|
test("converts simple text messages correctly", () => {
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{ role: 'user', content: 'Hello' },
|
{ role: "user", content: "Hello" },
|
||||||
{ role: 'assistant', content: 'Hi there' }
|
{ role: "assistant", content: "Hi there" },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = convertToBedrockConverseMessages(messages)
|
const result = convertToBedrockConverseMessages(messages)
|
||||||
|
|
||||||
expect(result).toEqual([
|
expect(result).toEqual([
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: [{ text: 'Hello' }]
|
content: [{ text: "Hello" }],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: [{ text: 'Hi there' }]
|
content: [{ text: "Hi there" }],
|
||||||
}
|
},
|
||||||
])
|
])
|
||||||
})
|
})
|
||||||
|
|
||||||
test('converts messages with images correctly', () => {
|
test("converts messages with images correctly", () => {
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Look at this image:'
|
text: "Look at this image:",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'image',
|
type: "image",
|
||||||
source: {
|
source: {
|
||||||
type: 'base64',
|
type: "base64",
|
||||||
data: 'SGVsbG8=', // "Hello" in base64
|
data: "SGVsbG8=", // "Hello" in base64
|
||||||
media_type: 'image/jpeg' as const
|
media_type: "image/jpeg" as const,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = convertToBedrockConverseMessages(messages)
|
const result = convertToBedrockConverseMessages(messages)
|
||||||
|
|
||||||
if (!result[0] || !result[0].content) {
|
if (!result[0] || !result[0].content) {
|
||||||
fail('Expected result to have content')
|
fail("Expected result to have content")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(result[0].role).toBe('user')
|
expect(result[0].role).toBe("user")
|
||||||
expect(result[0].content).toHaveLength(2)
|
expect(result[0].content).toHaveLength(2)
|
||||||
expect(result[0].content[0]).toEqual({ text: 'Look at this image:' })
|
expect(result[0].content[0]).toEqual({ text: "Look at this image:" })
|
||||||
|
|
||||||
const imageBlock = result[0].content[1] as ContentBlock
|
|
||||||
if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) {
|
|
||||||
expect(imageBlock.image.format).toBe('jpeg')
|
|
||||||
expect(imageBlock.image.source).toBeDefined()
|
|
||||||
expect(imageBlock.image.source.bytes).toBeDefined()
|
|
||||||
} else {
|
|
||||||
fail('Expected image block not found')
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
test('converts tool use messages correctly', () => {
|
const imageBlock = result[0].content[1] as ContentBlock
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) {
|
||||||
{
|
expect(imageBlock.image.format).toBe("jpeg")
|
||||||
role: 'assistant',
|
expect(imageBlock.image.source).toBeDefined()
|
||||||
content: [
|
expect(imageBlock.image.source.bytes).toBeDefined()
|
||||||
{
|
} else {
|
||||||
type: 'tool_use',
|
fail("Expected image block not found")
|
||||||
id: 'test-id',
|
}
|
||||||
name: 'read_file',
|
})
|
||||||
input: {
|
|
||||||
path: 'test.txt'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
const result = convertToBedrockConverseMessages(messages)
|
test("converts tool use messages correctly", () => {
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "tool_use",
|
||||||
|
id: "test-id",
|
||||||
|
name: "read_file",
|
||||||
|
input: {
|
||||||
|
path: "test.txt",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
if (!result[0] || !result[0].content) {
|
const result = convertToBedrockConverseMessages(messages)
|
||||||
fail('Expected result to have content')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(result[0].role).toBe('assistant')
|
if (!result[0] || !result[0].content) {
|
||||||
const toolBlock = result[0].content[0] as ContentBlock
|
fail("Expected result to have content")
|
||||||
if ('toolUse' in toolBlock && toolBlock.toolUse) {
|
return
|
||||||
expect(toolBlock.toolUse).toEqual({
|
}
|
||||||
toolUseId: 'test-id',
|
|
||||||
name: 'read_file',
|
|
||||||
input: '<read_file>\n<path>\ntest.txt\n</path>\n</read_file>'
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
fail('Expected tool use block not found')
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
test('converts tool result messages correctly', () => {
|
expect(result[0].role).toBe("assistant")
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const toolBlock = result[0].content[0] as ContentBlock
|
||||||
{
|
if ("toolUse" in toolBlock && toolBlock.toolUse) {
|
||||||
role: 'assistant',
|
expect(toolBlock.toolUse).toEqual({
|
||||||
content: [
|
toolUseId: "test-id",
|
||||||
{
|
name: "read_file",
|
||||||
type: 'tool_result',
|
input: "<read_file>\n<path>\ntest.txt\n</path>\n</read_file>",
|
||||||
tool_use_id: 'test-id',
|
})
|
||||||
content: [{ type: 'text', text: 'File contents here' }]
|
} else {
|
||||||
}
|
fail("Expected tool use block not found")
|
||||||
]
|
}
|
||||||
}
|
})
|
||||||
]
|
|
||||||
|
|
||||||
const result = convertToBedrockConverseMessages(messages)
|
test("converts tool result messages correctly", () => {
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "tool_result",
|
||||||
|
tool_use_id: "test-id",
|
||||||
|
content: [{ type: "text", text: "File contents here" }],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
if (!result[0] || !result[0].content) {
|
const result = convertToBedrockConverseMessages(messages)
|
||||||
fail('Expected result to have content')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(result[0].role).toBe('assistant')
|
if (!result[0] || !result[0].content) {
|
||||||
const resultBlock = result[0].content[0] as ContentBlock
|
fail("Expected result to have content")
|
||||||
if ('toolResult' in resultBlock && resultBlock.toolResult) {
|
return
|
||||||
const expectedContent: ToolResultContentBlock[] = [
|
}
|
||||||
{ text: 'File contents here' }
|
|
||||||
]
|
|
||||||
expect(resultBlock.toolResult).toEqual({
|
|
||||||
toolUseId: 'test-id',
|
|
||||||
content: expectedContent,
|
|
||||||
status: 'success'
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
fail('Expected tool result block not found')
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
test('handles text content correctly', () => {
|
expect(result[0].role).toBe("assistant")
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const resultBlock = result[0].content[0] as ContentBlock
|
||||||
{
|
if ("toolResult" in resultBlock && resultBlock.toolResult) {
|
||||||
role: 'user',
|
const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }]
|
||||||
content: [
|
expect(resultBlock.toolResult).toEqual({
|
||||||
{
|
toolUseId: "test-id",
|
||||||
type: 'text',
|
content: expectedContent,
|
||||||
text: 'Hello world'
|
status: "success",
|
||||||
}
|
})
|
||||||
]
|
} else {
|
||||||
}
|
fail("Expected tool result block not found")
|
||||||
]
|
}
|
||||||
|
})
|
||||||
|
|
||||||
const result = convertToBedrockConverseMessages(messages)
|
test("handles text content correctly", () => {
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text",
|
||||||
|
text: "Hello world",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
if (!result[0] || !result[0].content) {
|
const result = convertToBedrockConverseMessages(messages)
|
||||||
fail('Expected result to have content')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(result[0].role).toBe('user')
|
if (!result[0] || !result[0].content) {
|
||||||
expect(result[0].content).toHaveLength(1)
|
fail("Expected result to have content")
|
||||||
const textBlock = result[0].content[0] as ContentBlock
|
return
|
||||||
expect(textBlock).toEqual({ text: 'Hello world' })
|
}
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('convertToAnthropicMessage', () => {
|
expect(result[0].role).toBe("user")
|
||||||
test('converts metadata events correctly', () => {
|
expect(result[0].content).toHaveLength(1)
|
||||||
const event: StreamEvent = {
|
const textBlock = result[0].content[0] as ContentBlock
|
||||||
metadata: {
|
expect(textBlock).toEqual({ text: "Hello world" })
|
||||||
usage: {
|
})
|
||||||
inputTokens: 10,
|
})
|
||||||
outputTokens: 20
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = convertToAnthropicMessage(event, 'test-model')
|
describe("convertToAnthropicMessage", () => {
|
||||||
|
test("converts metadata events correctly", () => {
|
||||||
|
const event: StreamEvent = {
|
||||||
|
metadata: {
|
||||||
|
usage: {
|
||||||
|
inputTokens: 10,
|
||||||
|
outputTokens: 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expect(result).toEqual({
|
const result = convertToAnthropicMessage(event, "test-model")
|
||||||
id: '',
|
|
||||||
type: 'message',
|
|
||||||
role: 'assistant',
|
|
||||||
model: 'test-model',
|
|
||||||
usage: {
|
|
||||||
input_tokens: 10,
|
|
||||||
output_tokens: 20
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test('converts content block start events correctly', () => {
|
expect(result).toEqual({
|
||||||
const event: StreamEvent = {
|
id: "",
|
||||||
contentBlockStart: {
|
type: "message",
|
||||||
start: {
|
role: "assistant",
|
||||||
text: 'Hello'
|
model: "test-model",
|
||||||
}
|
usage: {
|
||||||
}
|
input_tokens: 10,
|
||||||
}
|
output_tokens: 20,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const result = convertToAnthropicMessage(event, 'test-model')
|
test("converts content block start events correctly", () => {
|
||||||
|
const event: StreamEvent = {
|
||||||
|
contentBlockStart: {
|
||||||
|
start: {
|
||||||
|
text: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expect(result).toEqual({
|
const result = convertToAnthropicMessage(event, "test-model")
|
||||||
type: 'message',
|
|
||||||
role: 'assistant',
|
|
||||||
content: [{ type: 'text', text: 'Hello' }],
|
|
||||||
model: 'test-model'
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test('converts content block delta events correctly', () => {
|
expect(result).toEqual({
|
||||||
const event: StreamEvent = {
|
type: "message",
|
||||||
contentBlockDelta: {
|
role: "assistant",
|
||||||
delta: {
|
content: [{ type: "text", text: "Hello" }],
|
||||||
text: ' world'
|
model: "test-model",
|
||||||
}
|
})
|
||||||
}
|
})
|
||||||
}
|
|
||||||
|
|
||||||
const result = convertToAnthropicMessage(event, 'test-model')
|
test("converts content block delta events correctly", () => {
|
||||||
|
const event: StreamEvent = {
|
||||||
|
contentBlockDelta: {
|
||||||
|
delta: {
|
||||||
|
text: " world",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expect(result).toEqual({
|
const result = convertToAnthropicMessage(event, "test-model")
|
||||||
type: 'message',
|
|
||||||
role: 'assistant',
|
|
||||||
content: [{ type: 'text', text: ' world' }],
|
|
||||||
model: 'test-model'
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test('converts message stop events correctly', () => {
|
expect(result).toEqual({
|
||||||
const event: StreamEvent = {
|
type: "message",
|
||||||
messageStop: {
|
role: "assistant",
|
||||||
stopReason: 'end_turn' as const
|
content: [{ type: "text", text: " world" }],
|
||||||
}
|
model: "test-model",
|
||||||
}
|
})
|
||||||
|
})
|
||||||
|
|
||||||
const result = convertToAnthropicMessage(event, 'test-model')
|
test("converts message stop events correctly", () => {
|
||||||
|
const event: StreamEvent = {
|
||||||
|
messageStop: {
|
||||||
|
stopReason: "end_turn" as const,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expect(result).toEqual({
|
const result = convertToAnthropicMessage(event, "test-model")
|
||||||
type: 'message',
|
|
||||||
role: 'assistant',
|
expect(result).toEqual({
|
||||||
stop_reason: 'end_turn',
|
type: "message",
|
||||||
stop_sequence: null,
|
role: "assistant",
|
||||||
model: 'test-model'
|
stop_reason: "end_turn",
|
||||||
})
|
stop_sequence: null,
|
||||||
})
|
model: "test-model",
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,257 +1,275 @@
|
|||||||
import { convertToOpenAiMessages, convertToAnthropicMessage } from '../openai-format';
|
import { convertToOpenAiMessages, convertToAnthropicMessage } from "../openai-format"
|
||||||
import { Anthropic } from '@anthropic-ai/sdk';
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import OpenAI from 'openai';
|
import OpenAI from "openai"
|
||||||
|
|
||||||
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, 'choices'> & {
|
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, "choices"> & {
|
||||||
choices: Array<Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
|
choices: Array<
|
||||||
message: OpenAI.Chat.Completions.ChatCompletion.Choice['message'];
|
Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
|
||||||
finish_reason: string;
|
message: OpenAI.Chat.Completions.ChatCompletion.Choice["message"]
|
||||||
index: number;
|
finish_reason: string
|
||||||
}>;
|
index: number
|
||||||
};
|
}
|
||||||
|
>
|
||||||
|
}
|
||||||
|
|
||||||
describe('OpenAI Format Transformations', () => {
|
describe("OpenAI Format Transformations", () => {
|
||||||
describe('convertToOpenAiMessages', () => {
|
describe("convertToOpenAiMessages", () => {
|
||||||
it('should convert simple text messages', () => {
|
it("should convert simple text messages", () => {
|
||||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: 'Hello'
|
content: "Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: 'Hi there!'
|
content: "Hi there!",
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||||
expect(openAiMessages).toHaveLength(2);
|
expect(openAiMessages).toHaveLength(2)
|
||||||
expect(openAiMessages[0]).toEqual({
|
expect(openAiMessages[0]).toEqual({
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: 'Hello'
|
content: "Hello",
|
||||||
});
|
})
|
||||||
expect(openAiMessages[1]).toEqual({
|
expect(openAiMessages[1]).toEqual({
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
content: 'Hi there!'
|
content: "Hi there!",
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle messages with image content', () => {
|
it("should handle messages with image content", () => {
|
||||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "user",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'What is in this image?'
|
text: "What is in this image?",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'image',
|
type: "image",
|
||||||
source: {
|
source: {
|
||||||
type: 'base64',
|
type: "base64",
|
||||||
media_type: 'image/jpeg',
|
media_type: "image/jpeg",
|
||||||
data: 'base64data'
|
data: "base64data",
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
];
|
]
|
||||||
|
|
||||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||||
expect(openAiMessages).toHaveLength(1);
|
expect(openAiMessages).toHaveLength(1)
|
||||||
expect(openAiMessages[0].role).toBe('user');
|
expect(openAiMessages[0].role).toBe("user")
|
||||||
|
|
||||||
const content = openAiMessages[0].content as Array<{
|
|
||||||
type: string;
|
|
||||||
text?: string;
|
|
||||||
image_url?: { url: string };
|
|
||||||
}>;
|
|
||||||
|
|
||||||
expect(Array.isArray(content)).toBe(true);
|
|
||||||
expect(content).toHaveLength(2);
|
|
||||||
expect(content[0]).toEqual({ type: 'text', text: 'What is in this image?' });
|
|
||||||
expect(content[1]).toEqual({
|
|
||||||
type: 'image_url',
|
|
||||||
image_url: { url: 'data:image/jpeg;base64,base64data' }
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle assistant messages with tool use', () => {
|
const content = openAiMessages[0].content as Array<{
|
||||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
type: string
|
||||||
{
|
text?: string
|
||||||
role: 'assistant',
|
image_url?: { url: string }
|
||||||
content: [
|
}>
|
||||||
{
|
|
||||||
type: 'text',
|
|
||||||
text: 'Let me check the weather.'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
type: 'tool_use',
|
|
||||||
id: 'weather-123',
|
|
||||||
name: 'get_weather',
|
|
||||||
input: { city: 'London' }
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
];
|
|
||||||
|
|
||||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
expect(Array.isArray(content)).toBe(true)
|
||||||
expect(openAiMessages).toHaveLength(1);
|
expect(content).toHaveLength(2)
|
||||||
|
expect(content[0]).toEqual({ type: "text", text: "What is in this image?" })
|
||||||
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam;
|
expect(content[1]).toEqual({
|
||||||
expect(assistantMessage.role).toBe('assistant');
|
type: "image_url",
|
||||||
expect(assistantMessage.content).toBe('Let me check the weather.');
|
image_url: { url: "data:image/jpeg;base64,base64data" },
|
||||||
expect(assistantMessage.tool_calls).toHaveLength(1);
|
})
|
||||||
expect(assistantMessage.tool_calls![0]).toEqual({
|
})
|
||||||
id: 'weather-123',
|
|
||||||
type: 'function',
|
|
||||||
function: {
|
|
||||||
name: 'get_weather',
|
|
||||||
arguments: JSON.stringify({ city: 'London' })
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle user messages with tool results', () => {
|
it("should handle assistant messages with tool use", () => {
|
||||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: "assistant",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: 'tool_result',
|
type: "text",
|
||||||
tool_use_id: 'weather-123',
|
text: "Let me check the weather.",
|
||||||
content: 'Current temperature in London: 20°C'
|
},
|
||||||
}
|
{
|
||||||
]
|
type: "tool_use",
|
||||||
}
|
id: "weather-123",
|
||||||
];
|
name: "get_weather",
|
||||||
|
input: { city: "London" },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||||
expect(openAiMessages).toHaveLength(1);
|
expect(openAiMessages).toHaveLength(1)
|
||||||
|
|
||||||
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam;
|
|
||||||
expect(toolMessage.role).toBe('tool');
|
|
||||||
expect(toolMessage.tool_call_id).toBe('weather-123');
|
|
||||||
expect(toolMessage.content).toBe('Current temperature in London: 20°C');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('convertToAnthropicMessage', () => {
|
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam
|
||||||
it('should convert simple completion', () => {
|
expect(assistantMessage.role).toBe("assistant")
|
||||||
const openAiCompletion: PartialChatCompletion = {
|
expect(assistantMessage.content).toBe("Let me check the weather.")
|
||||||
id: 'completion-123',
|
expect(assistantMessage.tool_calls).toHaveLength(1)
|
||||||
model: 'gpt-4',
|
expect(assistantMessage.tool_calls![0]).toEqual({
|
||||||
choices: [{
|
id: "weather-123",
|
||||||
message: {
|
type: "function",
|
||||||
role: 'assistant',
|
function: {
|
||||||
content: 'Hello there!',
|
name: "get_weather",
|
||||||
refusal: null
|
arguments: JSON.stringify({ city: "London" }),
|
||||||
},
|
},
|
||||||
finish_reason: 'stop',
|
})
|
||||||
index: 0
|
})
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 10,
|
|
||||||
completion_tokens: 5,
|
|
||||||
total_tokens: 15
|
|
||||||
},
|
|
||||||
created: 123456789,
|
|
||||||
object: 'chat.completion'
|
|
||||||
};
|
|
||||||
|
|
||||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
it("should handle user messages with tool results", () => {
|
||||||
expect(anthropicMessage.id).toBe('completion-123');
|
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||||
expect(anthropicMessage.role).toBe('assistant');
|
{
|
||||||
expect(anthropicMessage.content).toHaveLength(1);
|
role: "user",
|
||||||
expect(anthropicMessage.content[0]).toEqual({
|
content: [
|
||||||
type: 'text',
|
{
|
||||||
text: 'Hello there!'
|
type: "tool_result",
|
||||||
});
|
tool_use_id: "weather-123",
|
||||||
expect(anthropicMessage.stop_reason).toBe('end_turn');
|
content: "Current temperature in London: 20°C",
|
||||||
expect(anthropicMessage.usage).toEqual({
|
},
|
||||||
input_tokens: 10,
|
],
|
||||||
output_tokens: 5
|
},
|
||||||
});
|
]
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle tool calls in completion', () => {
|
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||||
const openAiCompletion: PartialChatCompletion = {
|
expect(openAiMessages).toHaveLength(1)
|
||||||
id: 'completion-123',
|
|
||||||
model: 'gpt-4',
|
|
||||||
choices: [{
|
|
||||||
message: {
|
|
||||||
role: 'assistant',
|
|
||||||
content: 'Let me check the weather.',
|
|
||||||
tool_calls: [{
|
|
||||||
id: 'weather-123',
|
|
||||||
type: 'function',
|
|
||||||
function: {
|
|
||||||
name: 'get_weather',
|
|
||||||
arguments: '{"city":"London"}'
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
refusal: null
|
|
||||||
},
|
|
||||||
finish_reason: 'tool_calls',
|
|
||||||
index: 0
|
|
||||||
}],
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: 15,
|
|
||||||
completion_tokens: 8,
|
|
||||||
total_tokens: 23
|
|
||||||
},
|
|
||||||
created: 123456789,
|
|
||||||
object: 'chat.completion'
|
|
||||||
};
|
|
||||||
|
|
||||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam
|
||||||
expect(anthropicMessage.content).toHaveLength(2);
|
expect(toolMessage.role).toBe("tool")
|
||||||
expect(anthropicMessage.content[0]).toEqual({
|
expect(toolMessage.tool_call_id).toBe("weather-123")
|
||||||
type: 'text',
|
expect(toolMessage.content).toBe("Current temperature in London: 20°C")
|
||||||
text: 'Let me check the weather.'
|
})
|
||||||
});
|
})
|
||||||
expect(anthropicMessage.content[1]).toEqual({
|
|
||||||
type: 'tool_use',
|
|
||||||
id: 'weather-123',
|
|
||||||
name: 'get_weather',
|
|
||||||
input: { city: 'London' }
|
|
||||||
});
|
|
||||||
expect(anthropicMessage.stop_reason).toBe('tool_use');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle invalid tool call arguments', () => {
|
describe("convertToAnthropicMessage", () => {
|
||||||
const openAiCompletion: PartialChatCompletion = {
|
it("should convert simple completion", () => {
|
||||||
id: 'completion-123',
|
const openAiCompletion: PartialChatCompletion = {
|
||||||
model: 'gpt-4',
|
id: "completion-123",
|
||||||
choices: [{
|
model: "gpt-4",
|
||||||
message: {
|
choices: [
|
||||||
role: 'assistant',
|
{
|
||||||
content: 'Testing invalid arguments',
|
message: {
|
||||||
tool_calls: [{
|
role: "assistant",
|
||||||
id: 'test-123',
|
content: "Hello there!",
|
||||||
type: 'function',
|
refusal: null,
|
||||||
function: {
|
},
|
||||||
name: 'test_function',
|
finish_reason: "stop",
|
||||||
arguments: 'invalid json'
|
index: 0,
|
||||||
}
|
},
|
||||||
}],
|
],
|
||||||
refusal: null
|
usage: {
|
||||||
},
|
prompt_tokens: 10,
|
||||||
finish_reason: 'tool_calls',
|
completion_tokens: 5,
|
||||||
index: 0
|
total_tokens: 15,
|
||||||
}],
|
},
|
||||||
created: 123456789,
|
created: 123456789,
|
||||||
object: 'chat.completion'
|
object: "chat.completion",
|
||||||
};
|
}
|
||||||
|
|
||||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
const anthropicMessage = convertToAnthropicMessage(
|
||||||
expect(anthropicMessage.content).toHaveLength(2);
|
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||||
expect(anthropicMessage.content[1]).toEqual({
|
)
|
||||||
type: 'tool_use',
|
expect(anthropicMessage.id).toBe("completion-123")
|
||||||
id: 'test-123',
|
expect(anthropicMessage.role).toBe("assistant")
|
||||||
name: 'test_function',
|
expect(anthropicMessage.content).toHaveLength(1)
|
||||||
input: {} // Should default to empty object for invalid JSON
|
expect(anthropicMessage.content[0]).toEqual({
|
||||||
});
|
type: "text",
|
||||||
});
|
text: "Hello there!",
|
||||||
});
|
})
|
||||||
});
|
expect(anthropicMessage.stop_reason).toBe("end_turn")
|
||||||
|
expect(anthropicMessage.usage).toEqual({
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle tool calls in completion", () => {
|
||||||
|
const openAiCompletion: PartialChatCompletion = {
|
||||||
|
id: "completion-123",
|
||||||
|
model: "gpt-4",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
message: {
|
||||||
|
role: "assistant",
|
||||||
|
content: "Let me check the weather.",
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "weather-123",
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "get_weather",
|
||||||
|
arguments: '{"city":"London"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
refusal: null,
|
||||||
|
},
|
||||||
|
finish_reason: "tool_calls",
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 15,
|
||||||
|
completion_tokens: 8,
|
||||||
|
total_tokens: 23,
|
||||||
|
},
|
||||||
|
created: 123456789,
|
||||||
|
object: "chat.completion",
|
||||||
|
}
|
||||||
|
|
||||||
|
const anthropicMessage = convertToAnthropicMessage(
|
||||||
|
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||||
|
)
|
||||||
|
expect(anthropicMessage.content).toHaveLength(2)
|
||||||
|
expect(anthropicMessage.content[0]).toEqual({
|
||||||
|
type: "text",
|
||||||
|
text: "Let me check the weather.",
|
||||||
|
})
|
||||||
|
expect(anthropicMessage.content[1]).toEqual({
|
||||||
|
type: "tool_use",
|
||||||
|
id: "weather-123",
|
||||||
|
name: "get_weather",
|
||||||
|
input: { city: "London" },
|
||||||
|
})
|
||||||
|
expect(anthropicMessage.stop_reason).toBe("tool_use")
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle invalid tool call arguments", () => {
|
||||||
|
const openAiCompletion: PartialChatCompletion = {
|
||||||
|
id: "completion-123",
|
||||||
|
model: "gpt-4",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
message: {
|
||||||
|
role: "assistant",
|
||||||
|
content: "Testing invalid arguments",
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "test-123",
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "test_function",
|
||||||
|
arguments: "invalid json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
refusal: null,
|
||||||
|
},
|
||||||
|
finish_reason: "tool_calls",
|
||||||
|
index: 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 123456789,
|
||||||
|
object: "chat.completion",
|
||||||
|
}
|
||||||
|
|
||||||
|
const anthropicMessage = convertToAnthropicMessage(
|
||||||
|
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||||
|
)
|
||||||
|
expect(anthropicMessage.content).toHaveLength(2)
|
||||||
|
expect(anthropicMessage.content[1]).toEqual({
|
||||||
|
type: "tool_use",
|
||||||
|
id: "test-123",
|
||||||
|
name: "test_function",
|
||||||
|
input: {}, // Should default to empty object for invalid JSON
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,114 +1,114 @@
|
|||||||
import { ApiStreamChunk } from '../stream';
|
import { ApiStreamChunk } from "../stream"
|
||||||
|
|
||||||
describe('API Stream Types', () => {
|
describe("API Stream Types", () => {
|
||||||
describe('ApiStreamChunk', () => {
|
describe("ApiStreamChunk", () => {
|
||||||
it('should correctly handle text chunks', () => {
|
it("should correctly handle text chunks", () => {
|
||||||
const textChunk: ApiStreamChunk = {
|
const textChunk: ApiStreamChunk = {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Hello world'
|
text: "Hello world",
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(textChunk.type).toBe('text');
|
expect(textChunk.type).toBe("text")
|
||||||
expect(textChunk.text).toBe('Hello world');
|
expect(textChunk.text).toBe("Hello world")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should correctly handle usage chunks with cache information', () => {
|
it("should correctly handle usage chunks with cache information", () => {
|
||||||
const usageChunk: ApiStreamChunk = {
|
const usageChunk: ApiStreamChunk = {
|
||||||
type: 'usage',
|
type: "usage",
|
||||||
inputTokens: 100,
|
inputTokens: 100,
|
||||||
outputTokens: 50,
|
outputTokens: 50,
|
||||||
cacheWriteTokens: 20,
|
cacheWriteTokens: 20,
|
||||||
cacheReadTokens: 10
|
cacheReadTokens: 10,
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(usageChunk.type).toBe('usage');
|
expect(usageChunk.type).toBe("usage")
|
||||||
expect(usageChunk.inputTokens).toBe(100);
|
expect(usageChunk.inputTokens).toBe(100)
|
||||||
expect(usageChunk.outputTokens).toBe(50);
|
expect(usageChunk.outputTokens).toBe(50)
|
||||||
expect(usageChunk.cacheWriteTokens).toBe(20);
|
expect(usageChunk.cacheWriteTokens).toBe(20)
|
||||||
expect(usageChunk.cacheReadTokens).toBe(10);
|
expect(usageChunk.cacheReadTokens).toBe(10)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle usage chunks without cache tokens', () => {
|
it("should handle usage chunks without cache tokens", () => {
|
||||||
const usageChunk: ApiStreamChunk = {
|
const usageChunk: ApiStreamChunk = {
|
||||||
type: 'usage',
|
type: "usage",
|
||||||
inputTokens: 100,
|
inputTokens: 100,
|
||||||
outputTokens: 50
|
outputTokens: 50,
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(usageChunk.type).toBe('usage');
|
expect(usageChunk.type).toBe("usage")
|
||||||
expect(usageChunk.inputTokens).toBe(100);
|
expect(usageChunk.inputTokens).toBe(100)
|
||||||
expect(usageChunk.outputTokens).toBe(50);
|
expect(usageChunk.outputTokens).toBe(50)
|
||||||
expect(usageChunk.cacheWriteTokens).toBeUndefined();
|
expect(usageChunk.cacheWriteTokens).toBeUndefined()
|
||||||
expect(usageChunk.cacheReadTokens).toBeUndefined();
|
expect(usageChunk.cacheReadTokens).toBeUndefined()
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle text chunks with empty strings', () => {
|
it("should handle text chunks with empty strings", () => {
|
||||||
const emptyTextChunk: ApiStreamChunk = {
|
const emptyTextChunk: ApiStreamChunk = {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: ''
|
text: "",
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(emptyTextChunk.type).toBe('text');
|
expect(emptyTextChunk.type).toBe("text")
|
||||||
expect(emptyTextChunk.text).toBe('');
|
expect(emptyTextChunk.text).toBe("")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle usage chunks with zero tokens', () => {
|
it("should handle usage chunks with zero tokens", () => {
|
||||||
const zeroUsageChunk: ApiStreamChunk = {
|
const zeroUsageChunk: ApiStreamChunk = {
|
||||||
type: 'usage',
|
type: "usage",
|
||||||
inputTokens: 0,
|
inputTokens: 0,
|
||||||
outputTokens: 0
|
outputTokens: 0,
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(zeroUsageChunk.type).toBe('usage');
|
expect(zeroUsageChunk.type).toBe("usage")
|
||||||
expect(zeroUsageChunk.inputTokens).toBe(0);
|
expect(zeroUsageChunk.inputTokens).toBe(0)
|
||||||
expect(zeroUsageChunk.outputTokens).toBe(0);
|
expect(zeroUsageChunk.outputTokens).toBe(0)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle usage chunks with large token counts', () => {
|
it("should handle usage chunks with large token counts", () => {
|
||||||
const largeUsageChunk: ApiStreamChunk = {
|
const largeUsageChunk: ApiStreamChunk = {
|
||||||
type: 'usage',
|
type: "usage",
|
||||||
inputTokens: 1000000,
|
inputTokens: 1000000,
|
||||||
outputTokens: 500000,
|
outputTokens: 500000,
|
||||||
cacheWriteTokens: 200000,
|
cacheWriteTokens: 200000,
|
||||||
cacheReadTokens: 100000
|
cacheReadTokens: 100000,
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(largeUsageChunk.type).toBe('usage');
|
expect(largeUsageChunk.type).toBe("usage")
|
||||||
expect(largeUsageChunk.inputTokens).toBe(1000000);
|
expect(largeUsageChunk.inputTokens).toBe(1000000)
|
||||||
expect(largeUsageChunk.outputTokens).toBe(500000);
|
expect(largeUsageChunk.outputTokens).toBe(500000)
|
||||||
expect(largeUsageChunk.cacheWriteTokens).toBe(200000);
|
expect(largeUsageChunk.cacheWriteTokens).toBe(200000)
|
||||||
expect(largeUsageChunk.cacheReadTokens).toBe(100000);
|
expect(largeUsageChunk.cacheReadTokens).toBe(100000)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle text chunks with special characters', () => {
|
it("should handle text chunks with special characters", () => {
|
||||||
const specialCharsChunk: ApiStreamChunk = {
|
const specialCharsChunk: ApiStreamChunk = {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: '!@#$%^&*()_+-=[]{}|;:,.<>?`~'
|
text: "!@#$%^&*()_+-=[]{}|;:,.<>?`~",
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(specialCharsChunk.type).toBe('text');
|
expect(specialCharsChunk.type).toBe("text")
|
||||||
expect(specialCharsChunk.text).toBe('!@#$%^&*()_+-=[]{}|;:,.<>?`~');
|
expect(specialCharsChunk.text).toBe("!@#$%^&*()_+-=[]{}|;:,.<>?`~")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle text chunks with unicode characters', () => {
|
it("should handle text chunks with unicode characters", () => {
|
||||||
const unicodeChunk: ApiStreamChunk = {
|
const unicodeChunk: ApiStreamChunk = {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: '你好世界👋🌍'
|
text: "你好世界👋🌍",
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(unicodeChunk.type).toBe('text');
|
expect(unicodeChunk.type).toBe("text")
|
||||||
expect(unicodeChunk.text).toBe('你好世界👋🌍');
|
expect(unicodeChunk.text).toBe("你好世界👋🌍")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle text chunks with multiline content', () => {
|
it("should handle text chunks with multiline content", () => {
|
||||||
const multilineChunk: ApiStreamChunk = {
|
const multilineChunk: ApiStreamChunk = {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Line 1\nLine 2\nLine 3'
|
text: "Line 1\nLine 2\nLine 3",
|
||||||
};
|
}
|
||||||
|
|
||||||
expect(multilineChunk.type).toBe('text');
|
expect(multilineChunk.type).toBe("text")
|
||||||
expect(multilineChunk.text).toBe('Line 1\nLine 2\nLine 3');
|
expect(multilineChunk.text).toBe("Line 1\nLine 2\nLine 3")
|
||||||
expect(multilineChunk.text.split('\n')).toHaveLength(3);
|
expect(multilineChunk.text.split("\n")).toHaveLength(3)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|||||||
@@ -1,66 +1,66 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk";
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import * as vscode from 'vscode';
|
import * as vscode from "vscode"
|
||||||
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from '../vscode-lm-format';
|
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from "../vscode-lm-format"
|
||||||
|
|
||||||
// Mock crypto
|
// Mock crypto
|
||||||
const mockCrypto = {
|
const mockCrypto = {
|
||||||
randomUUID: () => 'test-uuid'
|
randomUUID: () => "test-uuid",
|
||||||
};
|
}
|
||||||
global.crypto = mockCrypto as any;
|
global.crypto = mockCrypto as any
|
||||||
|
|
||||||
// Define types for our mocked classes
|
// Define types for our mocked classes
|
||||||
interface MockLanguageModelTextPart {
|
interface MockLanguageModelTextPart {
|
||||||
type: 'text';
|
type: "text"
|
||||||
value: string;
|
value: string
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MockLanguageModelToolCallPart {
|
interface MockLanguageModelToolCallPart {
|
||||||
type: 'tool_call';
|
type: "tool_call"
|
||||||
callId: string;
|
callId: string
|
||||||
name: string;
|
name: string
|
||||||
input: any;
|
input: any
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MockLanguageModelToolResultPart {
|
interface MockLanguageModelToolResultPart {
|
||||||
type: 'tool_result';
|
type: "tool_result"
|
||||||
toolUseId: string;
|
toolUseId: string
|
||||||
parts: MockLanguageModelTextPart[];
|
parts: MockLanguageModelTextPart[]
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart;
|
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart
|
||||||
|
|
||||||
interface MockLanguageModelChatMessage {
|
interface MockLanguageModelChatMessage {
|
||||||
role: string;
|
role: string
|
||||||
name?: string;
|
name?: string
|
||||||
content: MockMessageContent[];
|
content: MockMessageContent[]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mock vscode namespace
|
// Mock vscode namespace
|
||||||
jest.mock('vscode', () => {
|
jest.mock("vscode", () => {
|
||||||
const LanguageModelChatMessageRole = {
|
const LanguageModelChatMessageRole = {
|
||||||
Assistant: 'assistant',
|
Assistant: "assistant",
|
||||||
User: 'user'
|
User: "user",
|
||||||
};
|
}
|
||||||
|
|
||||||
class MockLanguageModelTextPart {
|
class MockLanguageModelTextPart {
|
||||||
type = 'text';
|
type = "text"
|
||||||
constructor(public value: string) {}
|
constructor(public value: string) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
class MockLanguageModelToolCallPart {
|
class MockLanguageModelToolCallPart {
|
||||||
type = 'tool_call';
|
type = "tool_call"
|
||||||
constructor(
|
constructor(
|
||||||
public callId: string,
|
public callId: string,
|
||||||
public name: string,
|
public name: string,
|
||||||
public input: any
|
public input: any,
|
||||||
) {}
|
) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
class MockLanguageModelToolResultPart {
|
class MockLanguageModelToolResultPart {
|
||||||
type = 'tool_result';
|
type = "tool_result"
|
||||||
constructor(
|
constructor(
|
||||||
public toolUseId: string,
|
public toolUseId: string,
|
||||||
public parts: MockLanguageModelTextPart[]
|
public parts: MockLanguageModelTextPart[],
|
||||||
) {}
|
) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,179 +68,189 @@ jest.mock('vscode', () => {
|
|||||||
LanguageModelChatMessage: {
|
LanguageModelChatMessage: {
|
||||||
Assistant: jest.fn((content) => ({
|
Assistant: jest.fn((content) => ({
|
||||||
role: LanguageModelChatMessageRole.Assistant,
|
role: LanguageModelChatMessageRole.Assistant,
|
||||||
name: 'assistant',
|
name: "assistant",
|
||||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||||
})),
|
})),
|
||||||
User: jest.fn((content) => ({
|
User: jest.fn((content) => ({
|
||||||
role: LanguageModelChatMessageRole.User,
|
role: LanguageModelChatMessageRole.User,
|
||||||
name: 'user',
|
name: "user",
|
||||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||||
}))
|
})),
|
||||||
},
|
},
|
||||||
LanguageModelChatMessageRole,
|
LanguageModelChatMessageRole,
|
||||||
LanguageModelTextPart: MockLanguageModelTextPart,
|
LanguageModelTextPart: MockLanguageModelTextPart,
|
||||||
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
||||||
LanguageModelToolResultPart: MockLanguageModelToolResultPart
|
LanguageModelToolResultPart: MockLanguageModelToolResultPart,
|
||||||
};
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('vscode-lm-format', () => {
|
describe("vscode-lm-format", () => {
|
||||||
describe('convertToVsCodeLmMessages', () => {
|
describe("convertToVsCodeLmMessages", () => {
|
||||||
it('should convert simple string messages', () => {
|
it("should convert simple string messages", () => {
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
{ role: 'user', content: 'Hello' },
|
{ role: "user", content: "Hello" },
|
||||||
{ role: 'assistant', content: 'Hi there' }
|
{ role: "assistant", content: "Hi there" },
|
||||||
];
|
]
|
||||||
|
|
||||||
const result = convertToVsCodeLmMessages(messages);
|
const result = convertToVsCodeLmMessages(messages)
|
||||||
|
|
||||||
expect(result).toHaveLength(2);
|
|
||||||
expect(result[0].role).toBe('user');
|
|
||||||
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe('Hello');
|
|
||||||
expect(result[1].role).toBe('assistant');
|
|
||||||
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe('Hi there');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle complex user messages with tool results', () => {
|
expect(result).toHaveLength(2)
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
expect(result[0].role).toBe("user")
|
||||||
role: 'user',
|
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe("Hello")
|
||||||
content: [
|
expect(result[1].role).toBe("assistant")
|
||||||
{ type: 'text', text: 'Here is the result:' },
|
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe("Hi there")
|
||||||
{
|
})
|
||||||
type: 'tool_result',
|
|
||||||
tool_use_id: 'tool-1',
|
|
||||||
content: 'Tool output'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}];
|
|
||||||
|
|
||||||
const result = convertToVsCodeLmMessages(messages);
|
it("should handle complex user messages with tool results", () => {
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
expect(result).toHaveLength(1);
|
{
|
||||||
expect(result[0].role).toBe('user');
|
role: "user",
|
||||||
expect(result[0].content).toHaveLength(2);
|
content: [
|
||||||
const [toolResult, textContent] = result[0].content as [MockLanguageModelToolResultPart, MockLanguageModelTextPart];
|
{ type: "text", text: "Here is the result:" },
|
||||||
expect(toolResult.type).toBe('tool_result');
|
{
|
||||||
expect(textContent.type).toBe('text');
|
type: "tool_result",
|
||||||
});
|
tool_use_id: "tool-1",
|
||||||
|
content: "Tool output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
it('should handle complex assistant messages with tool calls', () => {
|
const result = convertToVsCodeLmMessages(messages)
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
|
||||||
role: 'assistant',
|
|
||||||
content: [
|
|
||||||
{ type: 'text', text: 'Let me help you with that.' },
|
|
||||||
{
|
|
||||||
type: 'tool_use',
|
|
||||||
id: 'tool-1',
|
|
||||||
name: 'calculator',
|
|
||||||
input: { operation: 'add', numbers: [2, 2] }
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}];
|
|
||||||
|
|
||||||
const result = convertToVsCodeLmMessages(messages);
|
expect(result).toHaveLength(1)
|
||||||
|
expect(result[0].role).toBe("user")
|
||||||
expect(result).toHaveLength(1);
|
expect(result[0].content).toHaveLength(2)
|
||||||
expect(result[0].role).toBe('assistant');
|
const [toolResult, textContent] = result[0].content as [
|
||||||
expect(result[0].content).toHaveLength(2);
|
MockLanguageModelToolResultPart,
|
||||||
const [toolCall, textContent] = result[0].content as [MockLanguageModelToolCallPart, MockLanguageModelTextPart];
|
MockLanguageModelTextPart,
|
||||||
expect(toolCall.type).toBe('tool_call');
|
]
|
||||||
expect(textContent.type).toBe('text');
|
expect(toolResult.type).toBe("tool_result")
|
||||||
});
|
expect(textContent.type).toBe("text")
|
||||||
|
})
|
||||||
|
|
||||||
it('should handle image blocks with appropriate placeholders', () => {
|
it("should handle complex assistant messages with tool calls", () => {
|
||||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
role: 'user',
|
{
|
||||||
content: [
|
role: "assistant",
|
||||||
{ type: 'text', text: 'Look at this:' },
|
content: [
|
||||||
{
|
{ type: "text", text: "Let me help you with that." },
|
||||||
type: 'image',
|
{
|
||||||
source: {
|
type: "tool_use",
|
||||||
type: 'base64',
|
id: "tool-1",
|
||||||
media_type: 'image/png',
|
name: "calculator",
|
||||||
data: 'base64data'
|
input: { operation: "add", numbers: [2, 2] },
|
||||||
}
|
},
|
||||||
}
|
],
|
||||||
]
|
},
|
||||||
}];
|
]
|
||||||
|
|
||||||
const result = convertToVsCodeLmMessages(messages);
|
const result = convertToVsCodeLmMessages(messages)
|
||||||
|
|
||||||
expect(result).toHaveLength(1);
|
|
||||||
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart;
|
|
||||||
expect(imagePlaceholder.value).toContain('[Image (base64): image/png not supported by VSCode LM API]');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('convertToAnthropicRole', () => {
|
expect(result).toHaveLength(1)
|
||||||
it('should convert assistant role correctly', () => {
|
expect(result[0].role).toBe("assistant")
|
||||||
const result = convertToAnthropicRole('assistant' as any);
|
expect(result[0].content).toHaveLength(2)
|
||||||
expect(result).toBe('assistant');
|
const [toolCall, textContent] = result[0].content as [
|
||||||
});
|
MockLanguageModelToolCallPart,
|
||||||
|
MockLanguageModelTextPart,
|
||||||
|
]
|
||||||
|
expect(toolCall.type).toBe("tool_call")
|
||||||
|
expect(textContent.type).toBe("text")
|
||||||
|
})
|
||||||
|
|
||||||
it('should convert user role correctly', () => {
|
it("should handle image blocks with appropriate placeholders", () => {
|
||||||
const result = convertToAnthropicRole('user' as any);
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
expect(result).toBe('user');
|
{
|
||||||
});
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{ type: "text", text: "Look at this:" },
|
||||||
|
{
|
||||||
|
type: "image",
|
||||||
|
source: {
|
||||||
|
type: "base64",
|
||||||
|
media_type: "image/png",
|
||||||
|
data: "base64data",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
it('should return null for unknown roles', () => {
|
const result = convertToVsCodeLmMessages(messages)
|
||||||
const result = convertToAnthropicRole('unknown' as any);
|
|
||||||
expect(result).toBeNull();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('convertToAnthropicMessage', () => {
|
expect(result).toHaveLength(1)
|
||||||
it('should convert assistant message with text content', async () => {
|
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart
|
||||||
|
expect(imagePlaceholder.value).toContain("[Image (base64): image/png not supported by VSCode LM API]")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("convertToAnthropicRole", () => {
|
||||||
|
it("should convert assistant role correctly", () => {
|
||||||
|
const result = convertToAnthropicRole("assistant" as any)
|
||||||
|
expect(result).toBe("assistant")
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should convert user role correctly", () => {
|
||||||
|
const result = convertToAnthropicRole("user" as any)
|
||||||
|
expect(result).toBe("user")
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should return null for unknown roles", () => {
|
||||||
|
const result = convertToAnthropicRole("unknown" as any)
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("convertToAnthropicMessage", () => {
|
||||||
|
it("should convert assistant message with text content", async () => {
|
||||||
const vsCodeMessage = {
|
const vsCodeMessage = {
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
name: 'assistant',
|
name: "assistant",
|
||||||
content: [new vscode.LanguageModelTextPart('Hello')]
|
content: [new vscode.LanguageModelTextPart("Hello")],
|
||||||
};
|
}
|
||||||
|
|
||||||
const result = await convertToAnthropicMessage(vsCodeMessage as any);
|
const result = await convertToAnthropicMessage(vsCodeMessage as any)
|
||||||
|
|
||||||
expect(result.role).toBe('assistant');
|
expect(result.role).toBe("assistant")
|
||||||
expect(result.content).toHaveLength(1);
|
expect(result.content).toHaveLength(1)
|
||||||
expect(result.content[0]).toEqual({
|
expect(result.content[0]).toEqual({
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: 'Hello'
|
text: "Hello",
|
||||||
});
|
})
|
||||||
expect(result.id).toBe('test-uuid');
|
expect(result.id).toBe("test-uuid")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should convert assistant message with tool calls', async () => {
|
it("should convert assistant message with tool calls", async () => {
|
||||||
const vsCodeMessage = {
|
const vsCodeMessage = {
|
||||||
role: 'assistant',
|
role: "assistant",
|
||||||
name: 'assistant',
|
name: "assistant",
|
||||||
content: [new vscode.LanguageModelToolCallPart(
|
content: [
|
||||||
'call-1',
|
new vscode.LanguageModelToolCallPart("call-1", "calculator", { operation: "add", numbers: [2, 2] }),
|
||||||
'calculator',
|
],
|
||||||
{ operation: 'add', numbers: [2, 2] }
|
}
|
||||||
)]
|
|
||||||
};
|
|
||||||
|
|
||||||
const result = await convertToAnthropicMessage(vsCodeMessage as any);
|
const result = await convertToAnthropicMessage(vsCodeMessage as any)
|
||||||
|
|
||||||
expect(result.content).toHaveLength(1);
|
expect(result.content).toHaveLength(1)
|
||||||
expect(result.content[0]).toEqual({
|
expect(result.content[0]).toEqual({
|
||||||
type: 'tool_use',
|
type: "tool_use",
|
||||||
id: 'call-1',
|
id: "call-1",
|
||||||
name: 'calculator',
|
name: "calculator",
|
||||||
input: { operation: 'add', numbers: [2, 2] }
|
input: { operation: "add", numbers: [2, 2] },
|
||||||
});
|
})
|
||||||
expect(result.id).toBe('test-uuid');
|
expect(result.id).toBe("test-uuid")
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should throw error for non-assistant messages', async () => {
|
it("should throw error for non-assistant messages", async () => {
|
||||||
const vsCodeMessage = {
|
const vsCodeMessage = {
|
||||||
role: 'user',
|
role: "user",
|
||||||
name: 'user',
|
name: "user",
|
||||||
content: [new vscode.LanguageModelTextPart('Hello')]
|
content: [new vscode.LanguageModelTextPart("Hello")],
|
||||||
};
|
}
|
||||||
|
|
||||||
await expect(convertToAnthropicMessage(vsCodeMessage as any))
|
await expect(convertToAnthropicMessage(vsCodeMessage as any)).rejects.toThrow(
|
||||||
.rejects
|
"Cline <Language Model API>: Only assistant messages are supported.",
|
||||||
.toThrow('Cline <Language Model API>: Only assistant messages are supported.');
|
)
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|||||||
@@ -8,210 +8,216 @@ import { StreamEvent } from "../providers/bedrock"
|
|||||||
/**
|
/**
|
||||||
* Convert Anthropic messages to Bedrock Converse format
|
* Convert Anthropic messages to Bedrock Converse format
|
||||||
*/
|
*/
|
||||||
export function convertToBedrockConverseMessages(
|
export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
|
||||||
anthropicMessages: Anthropic.Messages.MessageParam[]
|
return anthropicMessages.map((anthropicMessage) => {
|
||||||
): Message[] {
|
// Map Anthropic roles to Bedrock roles
|
||||||
return anthropicMessages.map(anthropicMessage => {
|
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
|
||||||
// Map Anthropic roles to Bedrock roles
|
|
||||||
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
|
|
||||||
|
|
||||||
if (typeof anthropicMessage.content === "string") {
|
if (typeof anthropicMessage.content === "string") {
|
||||||
return {
|
return {
|
||||||
role,
|
role,
|
||||||
content: [{
|
content: [
|
||||||
text: anthropicMessage.content
|
{
|
||||||
}] as ContentBlock[]
|
text: anthropicMessage.content,
|
||||||
}
|
},
|
||||||
}
|
] as ContentBlock[],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Process complex content types
|
// Process complex content types
|
||||||
const content = anthropicMessage.content.map(block => {
|
const content = anthropicMessage.content.map((block) => {
|
||||||
const messageBlock = block as MessageContent & {
|
const messageBlock = block as MessageContent & {
|
||||||
id?: string,
|
id?: string
|
||||||
tool_use_id?: string,
|
tool_use_id?: string
|
||||||
content?: Array<{ type: string, text: string }>,
|
content?: Array<{ type: string; text: string }>
|
||||||
output?: string | Array<{ type: string, text: string }>
|
output?: string | Array<{ type: string; text: string }>
|
||||||
}
|
}
|
||||||
|
|
||||||
if (messageBlock.type === "text") {
|
if (messageBlock.type === "text") {
|
||||||
return {
|
return {
|
||||||
text: messageBlock.text || ''
|
text: messageBlock.text || "",
|
||||||
} as ContentBlock
|
} as ContentBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
if (messageBlock.type === "image" && messageBlock.source) {
|
|
||||||
// Convert base64 string to byte array if needed
|
|
||||||
let byteArray: Uint8Array
|
|
||||||
if (typeof messageBlock.source.data === 'string') {
|
|
||||||
const binaryString = atob(messageBlock.source.data)
|
|
||||||
byteArray = new Uint8Array(binaryString.length)
|
|
||||||
for (let i = 0; i < binaryString.length; i++) {
|
|
||||||
byteArray[i] = binaryString.charCodeAt(i)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
byteArray = messageBlock.source.data
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
|
if (messageBlock.type === "image" && messageBlock.source) {
|
||||||
const format = messageBlock.source.media_type.split('/')[1]
|
// Convert base64 string to byte array if needed
|
||||||
if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) {
|
let byteArray: Uint8Array
|
||||||
throw new Error(`Unsupported image format: ${format}`)
|
if (typeof messageBlock.source.data === "string") {
|
||||||
}
|
const binaryString = atob(messageBlock.source.data)
|
||||||
|
byteArray = new Uint8Array(binaryString.length)
|
||||||
|
for (let i = 0; i < binaryString.length; i++) {
|
||||||
|
byteArray[i] = binaryString.charCodeAt(i)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
byteArray = messageBlock.source.data
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
|
||||||
image: {
|
const format = messageBlock.source.media_type.split("/")[1]
|
||||||
format: format as "png" | "jpeg" | "gif" | "webp",
|
if (!["png", "jpeg", "gif", "webp"].includes(format)) {
|
||||||
source: {
|
throw new Error(`Unsupported image format: ${format}`)
|
||||||
bytes: byteArray
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
} as ContentBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
if (messageBlock.type === "tool_use") {
|
return {
|
||||||
// Convert tool use to XML format
|
image: {
|
||||||
const toolParams = Object.entries(messageBlock.input || {})
|
format: format as "png" | "jpeg" | "gif" | "webp",
|
||||||
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
|
source: {
|
||||||
.join('\n')
|
bytes: byteArray,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} as ContentBlock
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
if (messageBlock.type === "tool_use") {
|
||||||
toolUse: {
|
// Convert tool use to XML format
|
||||||
toolUseId: messageBlock.id || '',
|
const toolParams = Object.entries(messageBlock.input || {})
|
||||||
name: messageBlock.name || '',
|
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
|
||||||
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`
|
.join("\n")
|
||||||
}
|
|
||||||
} as ContentBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
if (messageBlock.type === "tool_result") {
|
return {
|
||||||
// First try to use content if available
|
toolUse: {
|
||||||
if (messageBlock.content && Array.isArray(messageBlock.content)) {
|
toolUseId: messageBlock.id || "",
|
||||||
return {
|
name: messageBlock.name || "",
|
||||||
toolResult: {
|
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`,
|
||||||
toolUseId: messageBlock.tool_use_id || '',
|
},
|
||||||
content: messageBlock.content.map(item => ({
|
} as ContentBlock
|
||||||
text: item.text
|
}
|
||||||
})),
|
|
||||||
status: "success"
|
|
||||||
}
|
|
||||||
} as ContentBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to output handling if content is not available
|
if (messageBlock.type === "tool_result") {
|
||||||
if (messageBlock.output && typeof messageBlock.output === "string") {
|
// First try to use content if available
|
||||||
return {
|
if (messageBlock.content && Array.isArray(messageBlock.content)) {
|
||||||
toolResult: {
|
return {
|
||||||
toolUseId: messageBlock.tool_use_id || '',
|
toolResult: {
|
||||||
content: [{
|
toolUseId: messageBlock.tool_use_id || "",
|
||||||
text: messageBlock.output
|
content: messageBlock.content.map((item) => ({
|
||||||
}],
|
text: item.text,
|
||||||
status: "success"
|
})),
|
||||||
}
|
status: "success",
|
||||||
} as ContentBlock
|
},
|
||||||
}
|
} as ContentBlock
|
||||||
// Handle array of content blocks if output is an array
|
}
|
||||||
if (Array.isArray(messageBlock.output)) {
|
|
||||||
return {
|
|
||||||
toolResult: {
|
|
||||||
toolUseId: messageBlock.tool_use_id || '',
|
|
||||||
content: messageBlock.output.map(part => {
|
|
||||||
if (typeof part === "object" && "text" in part) {
|
|
||||||
return { text: part.text }
|
|
||||||
}
|
|
||||||
// Skip images in tool results as they're handled separately
|
|
||||||
if (typeof part === "object" && "type" in part && part.type === "image") {
|
|
||||||
return { text: "(see following message for image)" }
|
|
||||||
}
|
|
||||||
return { text: String(part) }
|
|
||||||
}),
|
|
||||||
status: "success"
|
|
||||||
}
|
|
||||||
} as ContentBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default case
|
// Fall back to output handling if content is not available
|
||||||
return {
|
if (messageBlock.output && typeof messageBlock.output === "string") {
|
||||||
toolResult: {
|
return {
|
||||||
toolUseId: messageBlock.tool_use_id || '',
|
toolResult: {
|
||||||
content: [{
|
toolUseId: messageBlock.tool_use_id || "",
|
||||||
text: String(messageBlock.output || '')
|
content: [
|
||||||
}],
|
{
|
||||||
status: "success"
|
text: messageBlock.output,
|
||||||
}
|
},
|
||||||
} as ContentBlock
|
],
|
||||||
}
|
status: "success",
|
||||||
|
},
|
||||||
|
} as ContentBlock
|
||||||
|
}
|
||||||
|
// Handle array of content blocks if output is an array
|
||||||
|
if (Array.isArray(messageBlock.output)) {
|
||||||
|
return {
|
||||||
|
toolResult: {
|
||||||
|
toolUseId: messageBlock.tool_use_id || "",
|
||||||
|
content: messageBlock.output.map((part) => {
|
||||||
|
if (typeof part === "object" && "text" in part) {
|
||||||
|
return { text: part.text }
|
||||||
|
}
|
||||||
|
// Skip images in tool results as they're handled separately
|
||||||
|
if (typeof part === "object" && "type" in part && part.type === "image") {
|
||||||
|
return { text: "(see following message for image)" }
|
||||||
|
}
|
||||||
|
return { text: String(part) }
|
||||||
|
}),
|
||||||
|
status: "success",
|
||||||
|
},
|
||||||
|
} as ContentBlock
|
||||||
|
}
|
||||||
|
|
||||||
if (messageBlock.type === "video") {
|
// Default case
|
||||||
const videoContent = messageBlock.s3Location ? {
|
return {
|
||||||
s3Location: {
|
toolResult: {
|
||||||
uri: messageBlock.s3Location.uri,
|
toolUseId: messageBlock.tool_use_id || "",
|
||||||
bucketOwner: messageBlock.s3Location.bucketOwner
|
content: [
|
||||||
}
|
{
|
||||||
} : messageBlock.source
|
text: String(messageBlock.output || ""),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
status: "success",
|
||||||
|
},
|
||||||
|
} as ContentBlock
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
if (messageBlock.type === "video") {
|
||||||
video: {
|
const videoContent = messageBlock.s3Location
|
||||||
format: "mp4", // Default to mp4, adjust based on actual format if needed
|
? {
|
||||||
source: videoContent
|
s3Location: {
|
||||||
}
|
uri: messageBlock.s3Location.uri,
|
||||||
} as ContentBlock
|
bucketOwner: messageBlock.s3Location.bucketOwner,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
: messageBlock.source
|
||||||
|
|
||||||
// Default case for unknown block types
|
return {
|
||||||
return {
|
video: {
|
||||||
text: '[Unknown Block Type]'
|
format: "mp4", // Default to mp4, adjust based on actual format if needed
|
||||||
} as ContentBlock
|
source: videoContent,
|
||||||
})
|
},
|
||||||
|
} as ContentBlock
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
// Default case for unknown block types
|
||||||
role,
|
return {
|
||||||
content
|
text: "[Unknown Block Type]",
|
||||||
}
|
} as ContentBlock
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert Bedrock Converse stream events to Anthropic message format
|
* Convert Bedrock Converse stream events to Anthropic message format
|
||||||
*/
|
*/
|
||||||
export function convertToAnthropicMessage(
|
export function convertToAnthropicMessage(
|
||||||
streamEvent: StreamEvent,
|
streamEvent: StreamEvent,
|
||||||
modelId: string
|
modelId: string,
|
||||||
): Partial<Anthropic.Messages.Message> {
|
): Partial<Anthropic.Messages.Message> {
|
||||||
// Handle metadata events
|
// Handle metadata events
|
||||||
if (streamEvent.metadata?.usage) {
|
if (streamEvent.metadata?.usage) {
|
||||||
return {
|
return {
|
||||||
id: '', // Bedrock doesn't provide message IDs
|
id: "", // Bedrock doesn't provide message IDs
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
model: modelId,
|
model: modelId,
|
||||||
usage: {
|
usage: {
|
||||||
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
|
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||||
output_tokens: streamEvent.metadata.usage.outputTokens || 0
|
output_tokens: streamEvent.metadata.usage.outputTokens || 0,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle content blocks
|
// Handle content blocks
|
||||||
const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text
|
const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text
|
||||||
if (text !== undefined) {
|
if (text !== undefined) {
|
||||||
return {
|
return {
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [{ type: "text", text: text }],
|
content: [{ type: "text", text: text }],
|
||||||
model: modelId
|
model: modelId,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle message stop
|
// Handle message stop
|
||||||
if (streamEvent.messageStop) {
|
if (streamEvent.messageStop) {
|
||||||
return {
|
return {
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
stop_reason: streamEvent.messageStop.stopReason || null,
|
stop_reason: streamEvent.messageStop.stopReason || null,
|
||||||
stop_sequence: null,
|
stop_sequence: null,
|
||||||
model: modelId
|
model: modelId,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { Anthropic } from "@anthropic-ai/sdk";
|
import { Anthropic } from "@anthropic-ai/sdk"
|
||||||
import * as vscode from 'vscode';
|
import * as vscode from "vscode"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Safely converts a value into a plain object.
|
* Safely converts a value into a plain object.
|
||||||
@@ -7,30 +7,31 @@ import * as vscode from 'vscode';
|
|||||||
function asObjectSafe(value: any): object {
|
function asObjectSafe(value: any): object {
|
||||||
// Handle null/undefined
|
// Handle null/undefined
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return {};
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Handle strings that might be JSON
|
// Handle strings that might be JSON
|
||||||
if (typeof value === 'string') {
|
if (typeof value === "string") {
|
||||||
return JSON.parse(value);
|
return JSON.parse(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle pre-existing objects
|
// Handle pre-existing objects
|
||||||
if (typeof value === 'object') {
|
if (typeof value === "object") {
|
||||||
return Object.assign({}, value);
|
return Object.assign({}, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return {};
|
return {}
|
||||||
}
|
} catch (error) {
|
||||||
catch (error) {
|
console.warn("Cline <Language Model API>: Failed to parse object:", error)
|
||||||
console.warn('Cline <Language Model API>: Failed to parse object:', error);
|
return {}
|
||||||
return {};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): vscode.LanguageModelChatMessage[] {
|
export function convertToVsCodeLmMessages(
|
||||||
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [];
|
anthropicMessages: Anthropic.Messages.MessageParam[],
|
||||||
|
): vscode.LanguageModelChatMessage[] {
|
||||||
|
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []
|
||||||
|
|
||||||
for (const anthropicMessage of anthropicMessages) {
|
for (const anthropicMessage of anthropicMessages) {
|
||||||
// Handle simple string messages
|
// Handle simple string messages
|
||||||
@@ -38,135 +39,129 @@ export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.
|
|||||||
vsCodeLmMessages.push(
|
vsCodeLmMessages.push(
|
||||||
anthropicMessage.role === "assistant"
|
anthropicMessage.role === "assistant"
|
||||||
? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content)
|
? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content)
|
||||||
: vscode.LanguageModelChatMessage.User(anthropicMessage.content)
|
: vscode.LanguageModelChatMessage.User(anthropicMessage.content),
|
||||||
);
|
)
|
||||||
continue;
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle complex message structures
|
// Handle complex message structures
|
||||||
switch (anthropicMessage.role) {
|
switch (anthropicMessage.role) {
|
||||||
case "user": {
|
case "user": {
|
||||||
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
||||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
|
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
|
||||||
toolMessages: Anthropic.ToolResultBlockParam[];
|
toolMessages: Anthropic.ToolResultBlockParam[]
|
||||||
}>(
|
}>(
|
||||||
(acc, part) => {
|
(acc, part) => {
|
||||||
if (part.type === "tool_result") {
|
if (part.type === "tool_result") {
|
||||||
acc.toolMessages.push(part);
|
acc.toolMessages.push(part)
|
||||||
|
} else if (part.type === "text" || part.type === "image") {
|
||||||
|
acc.nonToolMessages.push(part)
|
||||||
}
|
}
|
||||||
else if (part.type === "text" || part.type === "image") {
|
return acc
|
||||||
acc.nonToolMessages.push(part);
|
|
||||||
}
|
|
||||||
return acc;
|
|
||||||
},
|
},
|
||||||
{ nonToolMessages: [], toolMessages: [] },
|
{ nonToolMessages: [], toolMessages: [] },
|
||||||
);
|
)
|
||||||
|
|
||||||
// Process tool messages first then non-tool messages
|
// Process tool messages first then non-tool messages
|
||||||
const contentParts = [
|
const contentParts = [
|
||||||
// Convert tool messages to ToolResultParts
|
// Convert tool messages to ToolResultParts
|
||||||
...toolMessages.map((toolMessage) => {
|
...toolMessages.map((toolMessage) => {
|
||||||
// Process tool result content into TextParts
|
// Process tool result content into TextParts
|
||||||
const toolContentParts: vscode.LanguageModelTextPart[] = (
|
const toolContentParts: vscode.LanguageModelTextPart[] =
|
||||||
typeof toolMessage.content === "string"
|
typeof toolMessage.content === "string"
|
||||||
? [new vscode.LanguageModelTextPart(toolMessage.content)]
|
? [new vscode.LanguageModelTextPart(toolMessage.content)]
|
||||||
: (
|
: (toolMessage.content?.map((part) => {
|
||||||
toolMessage.content?.map((part) => {
|
|
||||||
if (part.type === "image") {
|
if (part.type === "image") {
|
||||||
return new vscode.LanguageModelTextPart(
|
return new vscode.LanguageModelTextPart(
|
||||||
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
|
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
|
||||||
);
|
)
|
||||||
}
|
}
|
||||||
return new vscode.LanguageModelTextPart(part.text);
|
return new vscode.LanguageModelTextPart(part.text)
|
||||||
})
|
}) ?? [new vscode.LanguageModelTextPart("")])
|
||||||
?? [new vscode.LanguageModelTextPart("")]
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
return new vscode.LanguageModelToolResultPart(
|
return new vscode.LanguageModelToolResultPart(toolMessage.tool_use_id, toolContentParts)
|
||||||
toolMessage.tool_use_id,
|
|
||||||
toolContentParts
|
|
||||||
);
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// Convert non-tool messages to TextParts after tool messages
|
// Convert non-tool messages to TextParts after tool messages
|
||||||
...nonToolMessages.map((part) => {
|
...nonToolMessages.map((part) => {
|
||||||
if (part.type === "image") {
|
if (part.type === "image") {
|
||||||
return new vscode.LanguageModelTextPart(
|
return new vscode.LanguageModelTextPart(
|
||||||
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
|
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
|
||||||
);
|
)
|
||||||
}
|
}
|
||||||
return new vscode.LanguageModelTextPart(part.text);
|
return new vscode.LanguageModelTextPart(part.text)
|
||||||
})
|
}),
|
||||||
];
|
]
|
||||||
|
|
||||||
// Add single user message with all content parts
|
// Add single user message with all content parts
|
||||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts));
|
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts))
|
||||||
break;
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
case "assistant": {
|
case "assistant": {
|
||||||
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
||||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
|
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
|
||||||
toolMessages: Anthropic.ToolUseBlockParam[];
|
toolMessages: Anthropic.ToolUseBlockParam[]
|
||||||
}>(
|
}>(
|
||||||
(acc, part) => {
|
(acc, part) => {
|
||||||
if (part.type === "tool_use") {
|
if (part.type === "tool_use") {
|
||||||
acc.toolMessages.push(part);
|
acc.toolMessages.push(part)
|
||||||
|
} else if (part.type === "text" || part.type === "image") {
|
||||||
|
acc.nonToolMessages.push(part)
|
||||||
}
|
}
|
||||||
else if (part.type === "text" || part.type === "image") {
|
return acc
|
||||||
acc.nonToolMessages.push(part);
|
|
||||||
}
|
|
||||||
return acc;
|
|
||||||
},
|
},
|
||||||
{ nonToolMessages: [], toolMessages: [] },
|
{ nonToolMessages: [], toolMessages: [] },
|
||||||
);
|
)
|
||||||
|
|
||||||
// Process tool messages first then non-tool messages
|
// Process tool messages first then non-tool messages
|
||||||
const contentParts = [
|
const contentParts = [
|
||||||
// Convert tool messages to ToolCallParts first
|
// Convert tool messages to ToolCallParts first
|
||||||
...toolMessages.map((toolMessage) =>
|
...toolMessages.map(
|
||||||
new vscode.LanguageModelToolCallPart(
|
(toolMessage) =>
|
||||||
toolMessage.id,
|
new vscode.LanguageModelToolCallPart(
|
||||||
toolMessage.name,
|
toolMessage.id,
|
||||||
asObjectSafe(toolMessage.input)
|
toolMessage.name,
|
||||||
)
|
asObjectSafe(toolMessage.input),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
|
|
||||||
// Convert non-tool messages to TextParts after tool messages
|
// Convert non-tool messages to TextParts after tool messages
|
||||||
...nonToolMessages.map((part) => {
|
...nonToolMessages.map((part) => {
|
||||||
if (part.type === "image") {
|
if (part.type === "image") {
|
||||||
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]");
|
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]")
|
||||||
}
|
}
|
||||||
return new vscode.LanguageModelTextPart(part.text);
|
return new vscode.LanguageModelTextPart(part.text)
|
||||||
})
|
}),
|
||||||
];
|
]
|
||||||
|
|
||||||
// Add the assistant message to the list of messages
|
// Add the assistant message to the list of messages
|
||||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts));
|
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts))
|
||||||
break;
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return vsCodeLmMessages;
|
return vsCodeLmMessages
|
||||||
}
|
}
|
||||||
|
|
||||||
export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null {
|
export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null {
|
||||||
switch (vsCodeLmMessageRole) {
|
switch (vsCodeLmMessageRole) {
|
||||||
case vscode.LanguageModelChatMessageRole.Assistant:
|
case vscode.LanguageModelChatMessageRole.Assistant:
|
||||||
return "assistant";
|
return "assistant"
|
||||||
case vscode.LanguageModelChatMessageRole.User:
|
case vscode.LanguageModelChatMessageRole.User:
|
||||||
return "user";
|
return "user"
|
||||||
default:
|
default:
|
||||||
return null;
|
return null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.LanguageModelChatMessage): Promise<Anthropic.Messages.Message> {
|
export async function convertToAnthropicMessage(
|
||||||
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role);
|
vsCodeLmMessage: vscode.LanguageModelChatMessage,
|
||||||
|
): Promise<Anthropic.Messages.Message> {
|
||||||
|
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role)
|
||||||
if (anthropicRole !== "assistant") {
|
if (anthropicRole !== "assistant") {
|
||||||
throw new Error("Cline <Language Model API>: Only assistant messages are supported.");
|
throw new Error("Cline <Language Model API>: Only assistant messages are supported.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -174,36 +169,32 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
|
|||||||
type: "message",
|
type: "message",
|
||||||
model: "vscode-lm",
|
model: "vscode-lm",
|
||||||
role: anthropicRole,
|
role: anthropicRole,
|
||||||
content: (
|
content: vsCodeLmMessage.content
|
||||||
vsCodeLmMessage.content
|
.map((part): Anthropic.ContentBlock | null => {
|
||||||
.map((part): Anthropic.ContentBlock | null => {
|
if (part instanceof vscode.LanguageModelTextPart) {
|
||||||
if (part instanceof vscode.LanguageModelTextPart) {
|
return {
|
||||||
return {
|
type: "text",
|
||||||
type: "text",
|
text: part.value,
|
||||||
text: part.value
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (part instanceof vscode.LanguageModelToolCallPart) {
|
if (part instanceof vscode.LanguageModelToolCallPart) {
|
||||||
return {
|
return {
|
||||||
type: "tool_use",
|
type: "tool_use",
|
||||||
id: part.callId || crypto.randomUUID(),
|
id: part.callId || crypto.randomUUID(),
|
||||||
name: part.name,
|
name: part.name,
|
||||||
input: asObjectSafe(part.input)
|
input: asObjectSafe(part.input),
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return null;
|
return null
|
||||||
})
|
})
|
||||||
.filter(
|
.filter((part): part is Anthropic.ContentBlock => part !== null),
|
||||||
(part): part is Anthropic.ContentBlock => part !== null
|
|
||||||
)
|
|
||||||
),
|
|
||||||
stop_reason: null,
|
stop_reason: null,
|
||||||
stop_sequence: null,
|
stop_sequence: null,
|
||||||
usage: {
|
usage: {
|
||||||
input_tokens: 0,
|
input_tokens: 0,
|
||||||
output_tokens: 0,
|
output_tokens: 0,
|
||||||
}
|
},
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,13 @@ import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
|
|||||||
import { ApiStream } from "../api/transform/stream"
|
import { ApiStream } from "../api/transform/stream"
|
||||||
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
|
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
|
||||||
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
|
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
|
||||||
import { extractTextFromFile, addLineNumbers, stripLineNumbers, everyLineHasLineNumbers, truncateOutput } from "../integrations/misc/extract-text"
|
import {
|
||||||
|
extractTextFromFile,
|
||||||
|
addLineNumbers,
|
||||||
|
stripLineNumbers,
|
||||||
|
everyLineHasLineNumbers,
|
||||||
|
truncateOutput,
|
||||||
|
} from "../integrations/misc/extract-text"
|
||||||
import { TerminalManager } from "../integrations/terminal/TerminalManager"
|
import { TerminalManager } from "../integrations/terminal/TerminalManager"
|
||||||
import { UrlContentFetcher } from "../services/browser/UrlContentFetcher"
|
import { UrlContentFetcher } from "../services/browser/UrlContentFetcher"
|
||||||
import { listFiles } from "../services/glob/list-files"
|
import { listFiles } from "../services/glob/list-files"
|
||||||
@@ -112,7 +118,7 @@ export class Cline {
|
|||||||
experimentalDiffStrategy: boolean = false,
|
experimentalDiffStrategy: boolean = false,
|
||||||
) {
|
) {
|
||||||
if (!task && !images && !historyItem) {
|
if (!task && !images && !historyItem) {
|
||||||
throw new Error('Either historyItem or task/images must be provided');
|
throw new Error("Either historyItem or task/images must be provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
this.taskId = crypto.randomUUID()
|
this.taskId = crypto.randomUUID()
|
||||||
@@ -144,7 +150,8 @@ export class Cline {
|
|||||||
async updateDiffStrategy(experimentalDiffStrategy?: boolean) {
|
async updateDiffStrategy(experimentalDiffStrategy?: boolean) {
|
||||||
// If not provided, get from current state
|
// If not provided, get from current state
|
||||||
if (experimentalDiffStrategy === undefined) {
|
if (experimentalDiffStrategy === undefined) {
|
||||||
const { experimentalDiffStrategy: stateExperimentalDiffStrategy } = await this.providerRef.deref()?.getState() ?? {}
|
const { experimentalDiffStrategy: stateExperimentalDiffStrategy } =
|
||||||
|
(await this.providerRef.deref()?.getState()) ?? {}
|
||||||
experimentalDiffStrategy = stateExperimentalDiffStrategy ?? false
|
experimentalDiffStrategy = stateExperimentalDiffStrategy ?? false
|
||||||
}
|
}
|
||||||
this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy)
|
this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy)
|
||||||
@@ -471,7 +478,7 @@ export class Cline {
|
|||||||
// need to make sure that the api conversation history can be resumed by the api, even if it goes out of sync with cline messages
|
// need to make sure that the api conversation history can be resumed by the api, even if it goes out of sync with cline messages
|
||||||
|
|
||||||
let existingApiConversationHistory: Anthropic.Messages.MessageParam[] =
|
let existingApiConversationHistory: Anthropic.Messages.MessageParam[] =
|
||||||
await this.getSavedApiConversationHistory()
|
await this.getSavedApiConversationHistory()
|
||||||
|
|
||||||
// Now present the cline messages to the user and ask if they want to resume
|
// Now present the cline messages to the user and ask if they want to resume
|
||||||
|
|
||||||
@@ -582,8 +589,8 @@ export class Cline {
|
|||||||
: [{ type: "text", text: lastMessage.content }]
|
: [{ type: "text", text: lastMessage.content }]
|
||||||
if (previousAssistantMessage && previousAssistantMessage.role === "assistant") {
|
if (previousAssistantMessage && previousAssistantMessage.role === "assistant") {
|
||||||
const assistantContent = Array.isArray(previousAssistantMessage.content)
|
const assistantContent = Array.isArray(previousAssistantMessage.content)
|
||||||
? previousAssistantMessage.content
|
? previousAssistantMessage.content
|
||||||
: [{ type: "text", text: previousAssistantMessage.content }]
|
: [{ type: "text", text: previousAssistantMessage.content }]
|
||||||
|
|
||||||
const toolUseBlocks = assistantContent.filter(
|
const toolUseBlocks = assistantContent.filter(
|
||||||
(block) => block.type === "tool_use",
|
(block) => block.type === "tool_use",
|
||||||
@@ -756,8 +763,8 @@ export class Cline {
|
|||||||
// grouping command_output messages despite any gaps anyways)
|
// grouping command_output messages despite any gaps anyways)
|
||||||
await delay(50)
|
await delay(50)
|
||||||
|
|
||||||
const { terminalOutputLineLimit } = await this.providerRef.deref()?.getState() ?? {}
|
const { terminalOutputLineLimit } = (await this.providerRef.deref()?.getState()) ?? {}
|
||||||
const output = truncateOutput(lines.join('\n'), terminalOutputLineLimit)
|
const output = truncateOutput(lines.join("\n"), terminalOutputLineLimit)
|
||||||
const result = output.trim()
|
const result = output.trim()
|
||||||
|
|
||||||
if (userFeedback) {
|
if (userFeedback) {
|
||||||
@@ -788,7 +795,8 @@ export class Cline {
|
|||||||
async *attemptApiRequest(previousApiReqIndex: number): ApiStream {
|
async *attemptApiRequest(previousApiReqIndex: number): ApiStream {
|
||||||
let mcpHub: McpHub | undefined
|
let mcpHub: McpHub | undefined
|
||||||
|
|
||||||
const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } = await this.providerRef.deref()?.getState() ?? {}
|
const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } =
|
||||||
|
(await this.providerRef.deref()?.getState()) ?? {}
|
||||||
|
|
||||||
if (mcpEnabled ?? true) {
|
if (mcpEnabled ?? true) {
|
||||||
mcpHub = this.providerRef.deref()?.mcpHub
|
mcpHub = this.providerRef.deref()?.mcpHub
|
||||||
@@ -801,24 +809,27 @@ export class Cline {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const { browserViewportSize, preferredLanguage, mode, customPrompts } = await this.providerRef.deref()?.getState() ?? {}
|
const { browserViewportSize, preferredLanguage, mode, customPrompts } =
|
||||||
const systemPrompt = await SYSTEM_PROMPT(
|
(await this.providerRef.deref()?.getState()) ?? {}
|
||||||
cwd,
|
const systemPrompt =
|
||||||
this.api.getModel().info.supportsComputerUse ?? false,
|
(await SYSTEM_PROMPT(
|
||||||
mcpHub,
|
cwd,
|
||||||
this.diffStrategy,
|
this.api.getModel().info.supportsComputerUse ?? false,
|
||||||
browserViewportSize,
|
mcpHub,
|
||||||
mode,
|
this.diffStrategy,
|
||||||
customPrompts
|
browserViewportSize,
|
||||||
) + await addCustomInstructions(
|
mode,
|
||||||
{
|
|
||||||
customInstructions: this.customInstructions,
|
|
||||||
customPrompts,
|
customPrompts,
|
||||||
preferredLanguage
|
)) +
|
||||||
},
|
(await addCustomInstructions(
|
||||||
cwd,
|
{
|
||||||
mode
|
customInstructions: this.customInstructions,
|
||||||
)
|
customPrompts,
|
||||||
|
preferredLanguage,
|
||||||
|
},
|
||||||
|
cwd,
|
||||||
|
mode,
|
||||||
|
))
|
||||||
|
|
||||||
// If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request
|
// If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request
|
||||||
if (previousApiReqIndex >= 0) {
|
if (previousApiReqIndex >= 0) {
|
||||||
@@ -845,18 +856,18 @@ export class Cline {
|
|||||||
if (Array.isArray(content)) {
|
if (Array.isArray(content)) {
|
||||||
if (!this.api.getModel().info.supportsImages) {
|
if (!this.api.getModel().info.supportsImages) {
|
||||||
// Convert image blocks to text descriptions
|
// Convert image blocks to text descriptions
|
||||||
content = content.map(block => {
|
content = content.map((block) => {
|
||||||
if (block.type === 'image') {
|
if (block.type === "image") {
|
||||||
// Convert image blocks to text descriptions
|
// Convert image blocks to text descriptions
|
||||||
// Note: We can't access the actual image content/url due to API limitations,
|
// Note: We can't access the actual image content/url due to API limitations,
|
||||||
// but we can indicate that an image was present in the conversation
|
// but we can indicate that an image was present in the conversation
|
||||||
return {
|
return {
|
||||||
type: 'text',
|
type: "text",
|
||||||
text: '[Referenced image in conversation]'
|
text: "[Referenced image in conversation]",
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
return block;
|
return block
|
||||||
});
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return { role, content }
|
return { role, content }
|
||||||
@@ -876,7 +887,12 @@ export class Cline {
|
|||||||
// Automatically retry with delay
|
// Automatically retry with delay
|
||||||
// Show countdown timer in error color
|
// Show countdown timer in error color
|
||||||
for (let i = requestDelay; i > 0; i--) {
|
for (let i = requestDelay; i > 0; i--) {
|
||||||
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying in ${i} seconds...`, undefined, true)
|
await this.say(
|
||||||
|
"api_req_retry_delayed",
|
||||||
|
`${errorMsg}\n\nRetrying in ${i} seconds...`,
|
||||||
|
undefined,
|
||||||
|
true,
|
||||||
|
)
|
||||||
await delay(1000)
|
await delay(1000)
|
||||||
}
|
}
|
||||||
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying now...`, undefined, false)
|
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying now...`, undefined, false)
|
||||||
@@ -1125,7 +1141,7 @@ export class Cline {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate tool use based on current mode
|
// Validate tool use based on current mode
|
||||||
const { mode } = await this.providerRef.deref()?.getState() ?? {}
|
const { mode } = (await this.providerRef.deref()?.getState()) ?? {}
|
||||||
try {
|
try {
|
||||||
validateToolUse(block.name, mode ?? defaultModeSlug)
|
validateToolUse(block.name, mode ?? defaultModeSlug)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -1192,7 +1208,10 @@ export class Cline {
|
|||||||
await this.diffViewProvider.open(relPath)
|
await this.diffViewProvider.open(relPath)
|
||||||
}
|
}
|
||||||
// editor is open, stream content in
|
// editor is open, stream content in
|
||||||
await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, false)
|
await this.diffViewProvider.update(
|
||||||
|
everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent,
|
||||||
|
false,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
if (!relPath) {
|
if (!relPath) {
|
||||||
@@ -1209,7 +1228,9 @@ export class Cline {
|
|||||||
}
|
}
|
||||||
if (!predictedLineCount) {
|
if (!predictedLineCount) {
|
||||||
this.consecutiveMistakeCount++
|
this.consecutiveMistakeCount++
|
||||||
pushToolResult(await this.sayAndCreateMissingParamError("write_to_file", "line_count"))
|
pushToolResult(
|
||||||
|
await this.sayAndCreateMissingParamError("write_to_file", "line_count"),
|
||||||
|
)
|
||||||
await this.diffViewProvider.reset()
|
await this.diffViewProvider.reset()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -1224,17 +1245,28 @@ export class Cline {
|
|||||||
await this.ask("tool", partialMessage, true).catch(() => {}) // sending true for partial even though it's not a partial, this shows the edit row before the content is streamed into the editor
|
await this.ask("tool", partialMessage, true).catch(() => {}) // sending true for partial even though it's not a partial, this shows the edit row before the content is streamed into the editor
|
||||||
await this.diffViewProvider.open(relPath)
|
await this.diffViewProvider.open(relPath)
|
||||||
}
|
}
|
||||||
await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, true)
|
await this.diffViewProvider.update(
|
||||||
|
everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent,
|
||||||
|
true,
|
||||||
|
)
|
||||||
await delay(300) // wait for diff view to update
|
await delay(300) // wait for diff view to update
|
||||||
this.diffViewProvider.scrollToFirstDiff()
|
this.diffViewProvider.scrollToFirstDiff()
|
||||||
|
|
||||||
// Check for code omissions before proceeding
|
// Check for code omissions before proceeding
|
||||||
if (detectCodeOmission(this.diffViewProvider.originalContent || "", newContent, predictedLineCount)) {
|
if (
|
||||||
|
detectCodeOmission(
|
||||||
|
this.diffViewProvider.originalContent || "",
|
||||||
|
newContent,
|
||||||
|
predictedLineCount,
|
||||||
|
)
|
||||||
|
) {
|
||||||
if (this.diffStrategy) {
|
if (this.diffStrategy) {
|
||||||
await this.diffViewProvider.revertChanges()
|
await this.diffViewProvider.revertChanges()
|
||||||
pushToolResult(formatResponse.toolError(
|
pushToolResult(
|
||||||
`Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.`
|
formatResponse.toolError(
|
||||||
))
|
`Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.`,
|
||||||
|
),
|
||||||
|
)
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
vscode.window
|
vscode.window
|
||||||
@@ -1285,7 +1317,7 @@ export class Cline {
|
|||||||
pushToolResult(
|
pushToolResult(
|
||||||
`The user made the following updates to your content:\n\n${userEdits}\n\n` +
|
`The user made the following updates to your content:\n\n${userEdits}\n\n` +
|
||||||
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
|
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
|
||||||
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || '')}\n</final_file_content>\n\n` +
|
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || "")}\n</final_file_content>\n\n` +
|
||||||
`Please note:\n` +
|
`Please note:\n` +
|
||||||
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
|
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
|
||||||
`2. Proceed with the task using this updated file content as the new baseline.\n` +
|
`2. Proceed with the task using this updated file content as the new baseline.\n` +
|
||||||
@@ -1347,21 +1379,24 @@ export class Cline {
|
|||||||
const originalContent = await fs.readFile(absolutePath, "utf-8")
|
const originalContent = await fs.readFile(absolutePath, "utf-8")
|
||||||
|
|
||||||
// Apply the diff to the original content
|
// Apply the diff to the original content
|
||||||
const diffResult = await this.diffStrategy?.applyDiff(
|
const diffResult = (await this.diffStrategy?.applyDiff(
|
||||||
originalContent,
|
originalContent,
|
||||||
diffContent,
|
diffContent,
|
||||||
parseInt(block.params.start_line ?? ''),
|
parseInt(block.params.start_line ?? ""),
|
||||||
parseInt(block.params.end_line ?? '')
|
parseInt(block.params.end_line ?? ""),
|
||||||
) ?? {
|
)) ?? {
|
||||||
success: false,
|
success: false,
|
||||||
error: "No diff strategy available"
|
error: "No diff strategy available",
|
||||||
}
|
}
|
||||||
if (!diffResult.success) {
|
if (!diffResult.success) {
|
||||||
this.consecutiveMistakeCount++
|
this.consecutiveMistakeCount++
|
||||||
const currentCount = (this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1
|
const currentCount =
|
||||||
|
(this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1
|
||||||
this.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount)
|
this.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount)
|
||||||
const errorDetails = diffResult.details ? JSON.stringify(diffResult.details, null, 2) : ''
|
const errorDetails = diffResult.details
|
||||||
const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n<error_details>\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ''}\n</error_details>`
|
? JSON.stringify(diffResult.details, null, 2)
|
||||||
|
: ""
|
||||||
|
const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n<error_details>\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ""}\n</error_details>`
|
||||||
if (currentCount >= 2) {
|
if (currentCount >= 2) {
|
||||||
await this.say("error", formattedError)
|
await this.say("error", formattedError)
|
||||||
}
|
}
|
||||||
@@ -1373,9 +1408,9 @@ export class Cline {
|
|||||||
this.consecutiveMistakeCountForApplyDiff.delete(relPath)
|
this.consecutiveMistakeCountForApplyDiff.delete(relPath)
|
||||||
// Show diff view before asking for approval
|
// Show diff view before asking for approval
|
||||||
this.diffViewProvider.editType = "modify"
|
this.diffViewProvider.editType = "modify"
|
||||||
await this.diffViewProvider.open(relPath);
|
await this.diffViewProvider.open(relPath)
|
||||||
await this.diffViewProvider.update(diffResult.content, true);
|
await this.diffViewProvider.update(diffResult.content, true)
|
||||||
await this.diffViewProvider.scrollToFirstDiff();
|
await this.diffViewProvider.scrollToFirstDiff()
|
||||||
|
|
||||||
const completeMessage = JSON.stringify({
|
const completeMessage = JSON.stringify({
|
||||||
...sharedMessageProps,
|
...sharedMessageProps,
|
||||||
@@ -1403,7 +1438,7 @@ export class Cline {
|
|||||||
pushToolResult(
|
pushToolResult(
|
||||||
`The user made the following updates to your content:\n\n${userEdits}\n\n` +
|
`The user made the following updates to your content:\n\n${userEdits}\n\n` +
|
||||||
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
|
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
|
||||||
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || '')}\n</final_file_content>\n\n` +
|
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || "")}\n</final_file_content>\n\n` +
|
||||||
`Please note:\n` +
|
`Please note:\n` +
|
||||||
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
|
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
|
||||||
`2. Proceed with the task using this updated file content as the new baseline.\n` +
|
`2. Proceed with the task using this updated file content as the new baseline.\n` +
|
||||||
@@ -1411,7 +1446,9 @@ export class Cline {
|
|||||||
`${newProblemsMessage}`,
|
`${newProblemsMessage}`,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
pushToolResult(`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`)
|
pushToolResult(
|
||||||
|
`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
await this.diffViewProvider.reset()
|
await this.diffViewProvider.reset()
|
||||||
break
|
break
|
||||||
@@ -1615,7 +1652,7 @@ export class Cline {
|
|||||||
await this.ask(
|
await this.ask(
|
||||||
"browser_action_launch",
|
"browser_action_launch",
|
||||||
removeClosingTag("url", url),
|
removeClosingTag("url", url),
|
||||||
block.partial
|
block.partial,
|
||||||
).catch(() => {})
|
).catch(() => {})
|
||||||
} else {
|
} else {
|
||||||
await this.say(
|
await this.say(
|
||||||
@@ -1744,7 +1781,7 @@ export class Cline {
|
|||||||
try {
|
try {
|
||||||
if (block.partial) {
|
if (block.partial) {
|
||||||
await this.ask("command", removeClosingTag("command", command), block.partial).catch(
|
await this.ask("command", removeClosingTag("command", command), block.partial).catch(
|
||||||
() => {}
|
() => {},
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
@@ -2409,7 +2446,7 @@ export class Cline {
|
|||||||
Promise.all(
|
Promise.all(
|
||||||
userContent.map(async (block) => {
|
userContent.map(async (block) => {
|
||||||
const shouldProcessMentions = (text: string) =>
|
const shouldProcessMentions = (text: string) =>
|
||||||
text.includes("<task>") || text.includes("<feedback>");
|
text.includes("<task>") || text.includes("<feedback>")
|
||||||
|
|
||||||
if (block.type === "text") {
|
if (block.type === "text") {
|
||||||
if (shouldProcessMentions(block.text)) {
|
if (shouldProcessMentions(block.text)) {
|
||||||
@@ -2418,7 +2455,7 @@ export class Cline {
|
|||||||
text: await parseMentions(block.text, cwd, this.urlContentFetcher),
|
text: await parseMentions(block.text, cwd, this.urlContentFetcher),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return block;
|
return block
|
||||||
} else if (block.type === "tool_result") {
|
} else if (block.type === "tool_result") {
|
||||||
if (typeof block.content === "string") {
|
if (typeof block.content === "string") {
|
||||||
if (shouldProcessMentions(block.content)) {
|
if (shouldProcessMentions(block.content)) {
|
||||||
@@ -2427,7 +2464,7 @@ export class Cline {
|
|||||||
content: await parseMentions(block.content, cwd, this.urlContentFetcher),
|
content: await parseMentions(block.content, cwd, this.urlContentFetcher),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return block;
|
return block
|
||||||
} else if (Array.isArray(block.content)) {
|
} else if (Array.isArray(block.content)) {
|
||||||
const parsedContent = await Promise.all(
|
const parsedContent = await Promise.all(
|
||||||
block.content.map(async (contentBlock) => {
|
block.content.map(async (contentBlock) => {
|
||||||
@@ -2445,7 +2482,7 @@ export class Cline {
|
|||||||
content: parsedContent,
|
content: parsedContent,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return block;
|
return block
|
||||||
}
|
}
|
||||||
return block
|
return block
|
||||||
}),
|
}),
|
||||||
@@ -2571,26 +2608,29 @@ export class Cline {
|
|||||||
// Add current time information with timezone
|
// Add current time information with timezone
|
||||||
const now = new Date()
|
const now = new Date()
|
||||||
const formatter = new Intl.DateTimeFormat(undefined, {
|
const formatter = new Intl.DateTimeFormat(undefined, {
|
||||||
year: 'numeric',
|
year: "numeric",
|
||||||
month: 'numeric',
|
month: "numeric",
|
||||||
day: 'numeric',
|
day: "numeric",
|
||||||
hour: 'numeric',
|
hour: "numeric",
|
||||||
minute: 'numeric',
|
minute: "numeric",
|
||||||
second: 'numeric',
|
second: "numeric",
|
||||||
hour12: true
|
hour12: true,
|
||||||
})
|
})
|
||||||
const timeZone = formatter.resolvedOptions().timeZone
|
const timeZone = formatter.resolvedOptions().timeZone
|
||||||
const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation
|
const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation
|
||||||
const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? '+' : ''}${timeZoneOffset}:00`
|
const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? "+" : ""}${timeZoneOffset}:00`
|
||||||
details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})`
|
details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})`
|
||||||
|
|
||||||
// Add current mode and any mode-specific warnings
|
// Add current mode and any mode-specific warnings
|
||||||
const { mode } = await this.providerRef.deref()?.getState() ?? {}
|
const { mode } = (await this.providerRef.deref()?.getState()) ?? {}
|
||||||
const currentMode = mode ?? defaultModeSlug
|
const currentMode = mode ?? defaultModeSlug
|
||||||
details += `\n\n# Current Mode\n${currentMode}`
|
details += `\n\n# Current Mode\n${currentMode}`
|
||||||
|
|
||||||
// Add warning if not in code mode
|
// Add warning if not in code mode
|
||||||
if (!isToolAllowedForMode('write_to_file', currentMode) || !isToolAllowedForMode('execute_command', currentMode)) {
|
if (
|
||||||
|
!isToolAllowedForMode("write_to_file", currentMode) ||
|
||||||
|
!isToolAllowedForMode("execute_command", currentMode)
|
||||||
|
) {
|
||||||
details += `\n\nNOTE: You are currently in '${currentMode}' mode which only allows read-only operations. To write files or execute commands, the user will need to switch to '${defaultModeSlug}' mode. Note that only the user can switch modes.`
|
details += `\n\nNOTE: You are currently in '${currentMode}' mode which only allows read-only operations. To write files or execute commands, the user will need to switch to '${defaultModeSlug}' mode. Note that only the user can switch modes.`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2609,4 +2649,4 @@ export class Cline {
|
|||||||
|
|
||||||
return `<environment_details>\n${details.trim()}\n</environment_details>`
|
return `<environment_details>\n${details.trim()}\n</environment_details>`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,52 +1,52 @@
|
|||||||
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from '../../shared/modes';
|
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from "../../shared/modes"
|
||||||
import { validateToolUse } from '../mode-validator';
|
import { validateToolUse } from "../mode-validator"
|
||||||
|
|
||||||
const asTestTool = (tool: string): TestToolName => tool as TestToolName;
|
const asTestTool = (tool: string): TestToolName => tool as TestToolName
|
||||||
const [codeMode, architectMode, askMode] = modes.map(mode => mode.slug);
|
const [codeMode, architectMode, askMode] = modes.map((mode) => mode.slug)
|
||||||
|
|
||||||
describe('mode-validator', () => {
|
describe("mode-validator", () => {
|
||||||
describe('isToolAllowedForMode', () => {
|
describe("isToolAllowedForMode", () => {
|
||||||
describe('code mode', () => {
|
describe("code mode", () => {
|
||||||
it('allows all code mode tools', () => {
|
it("allows all code mode tools", () => {
|
||||||
const mode = getModeConfig(codeMode);
|
const mode = getModeConfig(codeMode)
|
||||||
mode.tools.forEach(([tool]) => {
|
mode.tools.forEach(([tool]) => {
|
||||||
expect(isToolAllowedForMode(tool, codeMode)).toBe(true)
|
expect(isToolAllowedForMode(tool, codeMode)).toBe(true)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('disallows unknown tools', () => {
|
it("disallows unknown tools", () => {
|
||||||
expect(isToolAllowedForMode(asTestTool('unknown_tool'), codeMode)).toBe(false)
|
expect(isToolAllowedForMode(asTestTool("unknown_tool"), codeMode)).toBe(false)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('architect mode', () => {
|
describe("architect mode", () => {
|
||||||
it('allows configured tools', () => {
|
it("allows configured tools", () => {
|
||||||
const mode = getModeConfig(architectMode);
|
const mode = getModeConfig(architectMode)
|
||||||
mode.tools.forEach(([tool]) => {
|
mode.tools.forEach(([tool]) => {
|
||||||
expect(isToolAllowedForMode(tool, architectMode)).toBe(true)
|
expect(isToolAllowedForMode(tool, architectMode)).toBe(true)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('ask mode', () => {
|
describe("ask mode", () => {
|
||||||
it('allows configured tools', () => {
|
it("allows configured tools", () => {
|
||||||
const mode = getModeConfig(askMode);
|
const mode = getModeConfig(askMode)
|
||||||
mode.tools.forEach(([tool]) => {
|
mode.tools.forEach(([tool]) => {
|
||||||
expect(isToolAllowedForMode(tool, askMode)).toBe(true)
|
expect(isToolAllowedForMode(tool, askMode)).toBe(true)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('validateToolUse', () => {
|
describe("validateToolUse", () => {
|
||||||
it('throws error for disallowed tools in architect mode', () => {
|
it("throws error for disallowed tools in architect mode", () => {
|
||||||
expect(() => validateToolUse('unknown_tool', 'architect')).toThrow(
|
expect(() => validateToolUse("unknown_tool", "architect")).toThrow(
|
||||||
'Tool "unknown_tool" is not allowed in architect mode.'
|
'Tool "unknown_tool" is not allowed in architect mode.',
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('does not throw for allowed tools in architect mode', () => {
|
it("does not throw for allowed tools in architect mode", () => {
|
||||||
expect(() => validateToolUse('read_file', 'architect')).not.toThrow()
|
expect(() => validateToolUse("read_file", "architect")).not.toThrow()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,221 +1,221 @@
|
|||||||
import { ExtensionContext } from 'vscode'
|
import { ExtensionContext } from "vscode"
|
||||||
import { ApiConfiguration } from '../../shared/api'
|
import { ApiConfiguration } from "../../shared/api"
|
||||||
import { Mode } from '../prompts/types'
|
import { Mode } from "../prompts/types"
|
||||||
import { ApiConfigMeta } from '../../shared/ExtensionMessage'
|
import { ApiConfigMeta } from "../../shared/ExtensionMessage"
|
||||||
|
|
||||||
export interface ApiConfigData {
|
export interface ApiConfigData {
|
||||||
currentApiConfigName: string
|
currentApiConfigName: string
|
||||||
apiConfigs: {
|
apiConfigs: {
|
||||||
[key: string]: ApiConfiguration
|
[key: string]: ApiConfiguration
|
||||||
}
|
}
|
||||||
modeApiConfigs?: Partial<Record<Mode, string>>
|
modeApiConfigs?: Partial<Record<Mode, string>>
|
||||||
}
|
}
|
||||||
|
|
||||||
export class ConfigManager {
|
export class ConfigManager {
|
||||||
private readonly defaultConfig: ApiConfigData = {
|
private readonly defaultConfig: ApiConfigData = {
|
||||||
currentApiConfigName: 'default',
|
currentApiConfigName: "default",
|
||||||
apiConfigs: {
|
apiConfigs: {
|
||||||
default: {
|
default: {
|
||||||
id: this.generateId()
|
id: this.generateId(),
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
private readonly SCOPE_PREFIX = "roo_cline_config_"
|
private readonly SCOPE_PREFIX = "roo_cline_config_"
|
||||||
private readonly context: ExtensionContext
|
private readonly context: ExtensionContext
|
||||||
|
|
||||||
constructor(context: ExtensionContext) {
|
constructor(context: ExtensionContext) {
|
||||||
this.context = context
|
this.context = context
|
||||||
this.initConfig().catch(console.error)
|
this.initConfig().catch(console.error)
|
||||||
}
|
}
|
||||||
|
|
||||||
private generateId(): string {
|
private generateId(): string {
|
||||||
return Math.random().toString(36).substring(2, 15)
|
return Math.random().toString(36).substring(2, 15)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize config if it doesn't exist
|
* Initialize config if it doesn't exist
|
||||||
*/
|
*/
|
||||||
async initConfig(): Promise<void> {
|
async initConfig(): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const config = await this.readConfig()
|
const config = await this.readConfig()
|
||||||
if (!config) {
|
if (!config) {
|
||||||
await this.writeConfig(this.defaultConfig)
|
await this.writeConfig(this.defaultConfig)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate: ensure all configs have IDs
|
// Migrate: ensure all configs have IDs
|
||||||
let needsMigration = false
|
let needsMigration = false
|
||||||
for (const [name, apiConfig] of Object.entries(config.apiConfigs)) {
|
for (const [name, apiConfig] of Object.entries(config.apiConfigs)) {
|
||||||
if (!apiConfig.id) {
|
if (!apiConfig.id) {
|
||||||
apiConfig.id = this.generateId()
|
apiConfig.id = this.generateId()
|
||||||
needsMigration = true
|
needsMigration = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (needsMigration) {
|
if (needsMigration) {
|
||||||
await this.writeConfig(config)
|
await this.writeConfig(config)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(`Failed to initialize config: ${error}`)
|
throw new Error(`Failed to initialize config: ${error}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List all available configs with metadata
|
* List all available configs with metadata
|
||||||
*/
|
*/
|
||||||
async ListConfig(): Promise<ApiConfigMeta[]> {
|
async ListConfig(): Promise<ApiConfigMeta[]> {
|
||||||
try {
|
try {
|
||||||
const config = await this.readConfig()
|
const config = await this.readConfig()
|
||||||
return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({
|
return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({
|
||||||
name,
|
name,
|
||||||
id: apiConfig.id || '',
|
id: apiConfig.id || "",
|
||||||
apiProvider: apiConfig.apiProvider,
|
apiProvider: apiConfig.apiProvider,
|
||||||
}))
|
}))
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(`Failed to list configs: ${error}`)
|
throw new Error(`Failed to list configs: ${error}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save a config with the given name
|
* Save a config with the given name
|
||||||
*/
|
*/
|
||||||
async SaveConfig(name: string, config: ApiConfiguration): Promise<void> {
|
async SaveConfig(name: string, config: ApiConfiguration): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const currentConfig = await this.readConfig()
|
const currentConfig = await this.readConfig()
|
||||||
const existingConfig = currentConfig.apiConfigs[name]
|
const existingConfig = currentConfig.apiConfigs[name]
|
||||||
currentConfig.apiConfigs[name] = {
|
currentConfig.apiConfigs[name] = {
|
||||||
...config,
|
...config,
|
||||||
id: existingConfig?.id || this.generateId()
|
id: existingConfig?.id || this.generateId(),
|
||||||
}
|
}
|
||||||
await this.writeConfig(currentConfig)
|
await this.writeConfig(currentConfig)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(`Failed to save config: ${error}`)
|
throw new Error(`Failed to save config: ${error}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load a config by name
|
* Load a config by name
|
||||||
*/
|
*/
|
||||||
async LoadConfig(name: string): Promise<ApiConfiguration> {
|
async LoadConfig(name: string): Promise<ApiConfiguration> {
|
||||||
try {
|
try {
|
||||||
const config = await this.readConfig()
|
const config = await this.readConfig()
|
||||||
const apiConfig = config.apiConfigs[name]
|
const apiConfig = config.apiConfigs[name]
|
||||||
|
|
||||||
if (!apiConfig) {
|
|
||||||
throw new Error(`Config '${name}' not found`)
|
|
||||||
}
|
|
||||||
|
|
||||||
config.currentApiConfigName = name;
|
|
||||||
await this.writeConfig(config)
|
|
||||||
|
|
||||||
return apiConfig
|
|
||||||
} catch (error) {
|
|
||||||
throw new Error(`Failed to load config: ${error}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
if (!apiConfig) {
|
||||||
* Delete a config by name
|
throw new Error(`Config '${name}' not found`)
|
||||||
*/
|
}
|
||||||
async DeleteConfig(name: string): Promise<void> {
|
|
||||||
try {
|
|
||||||
const currentConfig = await this.readConfig()
|
|
||||||
if (!currentConfig.apiConfigs[name]) {
|
|
||||||
throw new Error(`Config '${name}' not found`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't allow deleting the default config
|
config.currentApiConfigName = name
|
||||||
if (Object.keys(currentConfig.apiConfigs).length === 1) {
|
await this.writeConfig(config)
|
||||||
throw new Error(`Cannot delete the last remaining configuration.`)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete currentConfig.apiConfigs[name]
|
return apiConfig
|
||||||
await this.writeConfig(currentConfig)
|
} catch (error) {
|
||||||
} catch (error) {
|
throw new Error(`Failed to load config: ${error}`)
|
||||||
throw new Error(`Failed to delete config: ${error}`)
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the current active API configuration
|
* Delete a config by name
|
||||||
*/
|
*/
|
||||||
async SetCurrentConfig(name: string): Promise<void> {
|
async DeleteConfig(name: string): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const currentConfig = await this.readConfig()
|
const currentConfig = await this.readConfig()
|
||||||
if (!currentConfig.apiConfigs[name]) {
|
if (!currentConfig.apiConfigs[name]) {
|
||||||
throw new Error(`Config '${name}' not found`)
|
throw new Error(`Config '${name}' not found`)
|
||||||
}
|
}
|
||||||
|
|
||||||
currentConfig.currentApiConfigName = name
|
// Don't allow deleting the default config
|
||||||
await this.writeConfig(currentConfig)
|
if (Object.keys(currentConfig.apiConfigs).length === 1) {
|
||||||
} catch (error) {
|
throw new Error(`Cannot delete the last remaining configuration.`)
|
||||||
throw new Error(`Failed to set current config: ${error}`)
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
delete currentConfig.apiConfigs[name]
|
||||||
* Check if a config exists by name
|
await this.writeConfig(currentConfig)
|
||||||
*/
|
} catch (error) {
|
||||||
async HasConfig(name: string): Promise<boolean> {
|
throw new Error(`Failed to delete config: ${error}`)
|
||||||
try {
|
}
|
||||||
const config = await this.readConfig()
|
}
|
||||||
return name in config.apiConfigs
|
|
||||||
} catch (error) {
|
|
||||||
throw new Error(`Failed to check config existence: ${error}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the API config for a specific mode
|
* Set the current active API configuration
|
||||||
*/
|
*/
|
||||||
async SetModeConfig(mode: Mode, configId: string): Promise<void> {
|
async SetCurrentConfig(name: string): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const currentConfig = await this.readConfig()
|
const currentConfig = await this.readConfig()
|
||||||
if (!currentConfig.modeApiConfigs) {
|
if (!currentConfig.apiConfigs[name]) {
|
||||||
currentConfig.modeApiConfigs = {}
|
throw new Error(`Config '${name}' not found`)
|
||||||
}
|
}
|
||||||
currentConfig.modeApiConfigs[mode] = configId
|
|
||||||
await this.writeConfig(currentConfig)
|
|
||||||
} catch (error) {
|
|
||||||
throw new Error(`Failed to set mode config: ${error}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
currentConfig.currentApiConfigName = name
|
||||||
* Get the API config ID for a specific mode
|
await this.writeConfig(currentConfig)
|
||||||
*/
|
} catch (error) {
|
||||||
async GetModeConfigId(mode: Mode): Promise<string | undefined> {
|
throw new Error(`Failed to set current config: ${error}`)
|
||||||
try {
|
}
|
||||||
const config = await this.readConfig()
|
}
|
||||||
return config.modeApiConfigs?.[mode]
|
|
||||||
} catch (error) {
|
|
||||||
throw new Error(`Failed to get mode config: ${error}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private async readConfig(): Promise<ApiConfigData> {
|
/**
|
||||||
try {
|
* Check if a config exists by name
|
||||||
const configKey = `${this.SCOPE_PREFIX}api_config`
|
*/
|
||||||
const content = await this.context.secrets.get(configKey)
|
async HasConfig(name: string): Promise<boolean> {
|
||||||
|
try {
|
||||||
if (!content) {
|
const config = await this.readConfig()
|
||||||
return this.defaultConfig
|
return name in config.apiConfigs
|
||||||
}
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to check config existence: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return JSON.parse(content)
|
/**
|
||||||
} catch (error) {
|
* Set the API config for a specific mode
|
||||||
throw new Error(`Failed to read config from secrets: ${error}`)
|
*/
|
||||||
}
|
async SetModeConfig(mode: Mode, configId: string): Promise<void> {
|
||||||
}
|
try {
|
||||||
|
const currentConfig = await this.readConfig()
|
||||||
|
if (!currentConfig.modeApiConfigs) {
|
||||||
|
currentConfig.modeApiConfigs = {}
|
||||||
|
}
|
||||||
|
currentConfig.modeApiConfigs[mode] = configId
|
||||||
|
await this.writeConfig(currentConfig)
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to set mode config: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async writeConfig(config: ApiConfigData): Promise<void> {
|
/**
|
||||||
try {
|
* Get the API config ID for a specific mode
|
||||||
const configKey = `${this.SCOPE_PREFIX}api_config`
|
*/
|
||||||
const content = JSON.stringify(config, null, 2)
|
async GetModeConfigId(mode: Mode): Promise<string | undefined> {
|
||||||
await this.context.secrets.store(configKey, content)
|
try {
|
||||||
} catch (error) {
|
const config = await this.readConfig()
|
||||||
throw new Error(`Failed to write config to secrets: ${error}`)
|
return config.modeApiConfigs?.[mode]
|
||||||
}
|
} catch (error) {
|
||||||
}
|
throw new Error(`Failed to get mode config: ${error}`)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async readConfig(): Promise<ApiConfigData> {
|
||||||
|
try {
|
||||||
|
const configKey = `${this.SCOPE_PREFIX}api_config`
|
||||||
|
const content = await this.context.secrets.get(configKey)
|
||||||
|
|
||||||
|
if (!content) {
|
||||||
|
return this.defaultConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return JSON.parse(content)
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to read config from secrets: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async writeConfig(config: ApiConfigData): Promise<void> {
|
||||||
|
try {
|
||||||
|
const configKey = `${this.SCOPE_PREFIX}api_config`
|
||||||
|
const content = JSON.stringify(config, null, 2)
|
||||||
|
await this.context.secrets.store(configKey, content)
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to write config to secrets: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,452 +1,470 @@
|
|||||||
import { ExtensionContext } from 'vscode'
|
import { ExtensionContext } from "vscode"
|
||||||
import { ConfigManager, ApiConfigData } from '../ConfigManager'
|
import { ConfigManager, ApiConfigData } from "../ConfigManager"
|
||||||
import { ApiConfiguration } from '../../../shared/api'
|
import { ApiConfiguration } from "../../../shared/api"
|
||||||
|
|
||||||
// Mock VSCode ExtensionContext
|
// Mock VSCode ExtensionContext
|
||||||
const mockSecrets = {
|
const mockSecrets = {
|
||||||
get: jest.fn(),
|
get: jest.fn(),
|
||||||
store: jest.fn(),
|
store: jest.fn(),
|
||||||
delete: jest.fn()
|
delete: jest.fn(),
|
||||||
}
|
}
|
||||||
|
|
||||||
const mockContext = {
|
const mockContext = {
|
||||||
secrets: mockSecrets
|
secrets: mockSecrets,
|
||||||
} as unknown as ExtensionContext
|
} as unknown as ExtensionContext
|
||||||
|
|
||||||
describe('ConfigManager', () => {
|
describe("ConfigManager", () => {
|
||||||
let configManager: ConfigManager
|
let configManager: ConfigManager
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
configManager = new ConfigManager(mockContext)
|
configManager = new ConfigManager(mockContext)
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('initConfig', () => {
|
describe("initConfig", () => {
|
||||||
it('should not write to storage when secrets.get returns null', async () => {
|
it("should not write to storage when secrets.get returns null", async () => {
|
||||||
// Mock readConfig to return null
|
// Mock readConfig to return null
|
||||||
mockSecrets.get.mockResolvedValueOnce(null)
|
mockSecrets.get.mockResolvedValueOnce(null)
|
||||||
|
|
||||||
await configManager.initConfig()
|
await configManager.initConfig()
|
||||||
|
|
||||||
// Should not write to storage because readConfig returns defaultConfig
|
// Should not write to storage because readConfig returns defaultConfig
|
||||||
expect(mockSecrets.store).not.toHaveBeenCalled()
|
expect(mockSecrets.store).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not initialize config if it exists', async () => {
|
it("should not initialize config if it exists", async () => {
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
mockSecrets.get.mockResolvedValue(
|
||||||
currentApiConfigName: 'default',
|
JSON.stringify({
|
||||||
apiConfigs: {
|
currentApiConfigName: "default",
|
||||||
default: {
|
apiConfigs: {
|
||||||
config: {},
|
default: {
|
||||||
id: 'default'
|
config: {},
|
||||||
}
|
id: "default",
|
||||||
}
|
},
|
||||||
}))
|
},
|
||||||
|
}),
|
||||||
await configManager.initConfig()
|
)
|
||||||
|
|
||||||
expect(mockSecrets.store).not.toHaveBeenCalled()
|
await configManager.initConfig()
|
||||||
})
|
|
||||||
|
expect(mockSecrets.store).not.toHaveBeenCalled()
|
||||||
it('should generate IDs for configs that lack them', async () => {
|
})
|
||||||
// Mock a config with missing IDs
|
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
it("should generate IDs for configs that lack them", async () => {
|
||||||
currentApiConfigName: 'default',
|
// Mock a config with missing IDs
|
||||||
apiConfigs: {
|
mockSecrets.get.mockResolvedValue(
|
||||||
default: {
|
JSON.stringify({
|
||||||
config: {}
|
currentApiConfigName: "default",
|
||||||
},
|
apiConfigs: {
|
||||||
test: {
|
default: {
|
||||||
apiProvider: 'anthropic'
|
config: {},
|
||||||
}
|
},
|
||||||
}
|
test: {
|
||||||
}))
|
apiProvider: "anthropic",
|
||||||
|
},
|
||||||
await configManager.initConfig()
|
},
|
||||||
|
}),
|
||||||
// Should have written the config with new IDs
|
)
|
||||||
expect(mockSecrets.store).toHaveBeenCalled()
|
|
||||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
await configManager.initConfig()
|
||||||
expect(storedConfig.apiConfigs.default.id).toBeTruthy()
|
|
||||||
expect(storedConfig.apiConfigs.test.id).toBeTruthy()
|
// Should have written the config with new IDs
|
||||||
})
|
expect(mockSecrets.store).toHaveBeenCalled()
|
||||||
|
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||||
it('should throw error if secrets storage fails', async () => {
|
expect(storedConfig.apiConfigs.default.id).toBeTruthy()
|
||||||
mockSecrets.get.mockRejectedValue(new Error('Storage failed'))
|
expect(storedConfig.apiConfigs.test.id).toBeTruthy()
|
||||||
|
})
|
||||||
await expect(configManager.initConfig()).rejects.toThrow(
|
|
||||||
'Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed'
|
it("should throw error if secrets storage fails", async () => {
|
||||||
)
|
mockSecrets.get.mockRejectedValue(new Error("Storage failed"))
|
||||||
})
|
|
||||||
})
|
await expect(configManager.initConfig()).rejects.toThrow(
|
||||||
|
"Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed",
|
||||||
describe('ListConfig', () => {
|
)
|
||||||
it('should list all available configs', async () => {
|
})
|
||||||
const existingConfig: ApiConfigData = {
|
})
|
||||||
currentApiConfigName: 'default',
|
|
||||||
apiConfigs: {
|
describe("ListConfig", () => {
|
||||||
default: {
|
it("should list all available configs", async () => {
|
||||||
id: 'default'
|
const existingConfig: ApiConfigData = {
|
||||||
},
|
currentApiConfigName: "default",
|
||||||
test: {
|
apiConfigs: {
|
||||||
apiProvider: 'anthropic',
|
default: {
|
||||||
id: 'test-id'
|
id: "default",
|
||||||
}
|
},
|
||||||
},
|
test: {
|
||||||
modeApiConfigs: {
|
apiProvider: "anthropic",
|
||||||
code: 'default',
|
id: "test-id",
|
||||||
architect: 'default',
|
},
|
||||||
ask: 'default'
|
},
|
||||||
}
|
modeApiConfigs: {
|
||||||
}
|
code: "default",
|
||||||
|
architect: "default",
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
ask: "default",
|
||||||
|
},
|
||||||
const configs = await configManager.ListConfig()
|
}
|
||||||
expect(configs).toEqual([
|
|
||||||
{ name: 'default', id: 'default', apiProvider: undefined },
|
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||||
{ name: 'test', id: 'test-id', apiProvider: 'anthropic' }
|
|
||||||
])
|
const configs = await configManager.ListConfig()
|
||||||
})
|
expect(configs).toEqual([
|
||||||
|
{ name: "default", id: "default", apiProvider: undefined },
|
||||||
it('should handle empty config file', async () => {
|
{ name: "test", id: "test-id", apiProvider: "anthropic" },
|
||||||
const emptyConfig: ApiConfigData = {
|
])
|
||||||
currentApiConfigName: 'default',
|
})
|
||||||
apiConfigs: {},
|
|
||||||
modeApiConfigs: {
|
it("should handle empty config file", async () => {
|
||||||
code: 'default',
|
const emptyConfig: ApiConfigData = {
|
||||||
architect: 'default',
|
currentApiConfigName: "default",
|
||||||
ask: 'default'
|
apiConfigs: {},
|
||||||
}
|
modeApiConfigs: {
|
||||||
}
|
code: "default",
|
||||||
|
architect: "default",
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig))
|
ask: "default",
|
||||||
|
},
|
||||||
const configs = await configManager.ListConfig()
|
}
|
||||||
expect(configs).toEqual([])
|
|
||||||
})
|
mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig))
|
||||||
|
|
||||||
it('should throw error if reading from secrets fails', async () => {
|
const configs = await configManager.ListConfig()
|
||||||
mockSecrets.get.mockRejectedValue(new Error('Read failed'))
|
expect(configs).toEqual([])
|
||||||
|
})
|
||||||
await expect(configManager.ListConfig()).rejects.toThrow(
|
|
||||||
'Failed to list configs: Error: Failed to read config from secrets: Error: Read failed'
|
it("should throw error if reading from secrets fails", async () => {
|
||||||
)
|
mockSecrets.get.mockRejectedValue(new Error("Read failed"))
|
||||||
})
|
|
||||||
})
|
await expect(configManager.ListConfig()).rejects.toThrow(
|
||||||
|
"Failed to list configs: Error: Failed to read config from secrets: Error: Read failed",
|
||||||
describe('SaveConfig', () => {
|
)
|
||||||
it('should save new config', async () => {
|
})
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
})
|
||||||
currentApiConfigName: 'default',
|
|
||||||
apiConfigs: {
|
describe("SaveConfig", () => {
|
||||||
default: {}
|
it("should save new config", async () => {
|
||||||
},
|
mockSecrets.get.mockResolvedValue(
|
||||||
modeApiConfigs: {
|
JSON.stringify({
|
||||||
code: 'default',
|
currentApiConfigName: "default",
|
||||||
architect: 'default',
|
apiConfigs: {
|
||||||
ask: 'default'
|
default: {},
|
||||||
}
|
},
|
||||||
}))
|
modeApiConfigs: {
|
||||||
|
code: "default",
|
||||||
const newConfig: ApiConfiguration = {
|
architect: "default",
|
||||||
apiProvider: 'anthropic',
|
ask: "default",
|
||||||
apiKey: 'test-key'
|
},
|
||||||
}
|
}),
|
||||||
|
)
|
||||||
await configManager.SaveConfig('test', newConfig)
|
|
||||||
|
const newConfig: ApiConfiguration = {
|
||||||
// Get the actual stored config to check the generated ID
|
apiProvider: "anthropic",
|
||||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
apiKey: "test-key",
|
||||||
const testConfigId = storedConfig.apiConfigs.test.id
|
}
|
||||||
|
|
||||||
const expectedConfig = {
|
await configManager.SaveConfig("test", newConfig)
|
||||||
currentApiConfigName: 'default',
|
|
||||||
apiConfigs: {
|
// Get the actual stored config to check the generated ID
|
||||||
default: {},
|
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||||
test: {
|
const testConfigId = storedConfig.apiConfigs.test.id
|
||||||
...newConfig,
|
|
||||||
id: testConfigId
|
const expectedConfig = {
|
||||||
}
|
currentApiConfigName: "default",
|
||||||
},
|
apiConfigs: {
|
||||||
modeApiConfigs: {
|
default: {},
|
||||||
code: 'default',
|
test: {
|
||||||
architect: 'default',
|
...newConfig,
|
||||||
ask: 'default'
|
id: testConfigId,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
|
modeApiConfigs: {
|
||||||
expect(mockSecrets.store).toHaveBeenCalledWith(
|
code: "default",
|
||||||
'roo_cline_config_api_config',
|
architect: "default",
|
||||||
JSON.stringify(expectedConfig, null, 2)
|
ask: "default",
|
||||||
)
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
it('should update existing config', async () => {
|
expect(mockSecrets.store).toHaveBeenCalledWith(
|
||||||
const existingConfig: ApiConfigData = {
|
"roo_cline_config_api_config",
|
||||||
currentApiConfigName: 'default',
|
JSON.stringify(expectedConfig, null, 2),
|
||||||
apiConfigs: {
|
)
|
||||||
test: {
|
})
|
||||||
apiProvider: 'anthropic',
|
|
||||||
apiKey: 'old-key',
|
it("should update existing config", async () => {
|
||||||
id: 'test-id'
|
const existingConfig: ApiConfigData = {
|
||||||
}
|
currentApiConfigName: "default",
|
||||||
}
|
apiConfigs: {
|
||||||
}
|
test: {
|
||||||
|
apiProvider: "anthropic",
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
apiKey: "old-key",
|
||||||
|
id: "test-id",
|
||||||
const updatedConfig: ApiConfiguration = {
|
},
|
||||||
apiProvider: 'anthropic',
|
},
|
||||||
apiKey: 'new-key'
|
}
|
||||||
}
|
|
||||||
|
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||||
await configManager.SaveConfig('test', updatedConfig)
|
|
||||||
|
const updatedConfig: ApiConfiguration = {
|
||||||
const expectedConfig = {
|
apiProvider: "anthropic",
|
||||||
currentApiConfigName: 'default',
|
apiKey: "new-key",
|
||||||
apiConfigs: {
|
}
|
||||||
test: {
|
|
||||||
apiProvider: 'anthropic',
|
await configManager.SaveConfig("test", updatedConfig)
|
||||||
apiKey: 'new-key',
|
|
||||||
id: 'test-id'
|
const expectedConfig = {
|
||||||
}
|
currentApiConfigName: "default",
|
||||||
}
|
apiConfigs: {
|
||||||
}
|
test: {
|
||||||
|
apiProvider: "anthropic",
|
||||||
expect(mockSecrets.store).toHaveBeenCalledWith(
|
apiKey: "new-key",
|
||||||
'roo_cline_config_api_config',
|
id: "test-id",
|
||||||
JSON.stringify(expectedConfig, null, 2)
|
},
|
||||||
)
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
it('should throw error if secrets storage fails', async () => {
|
expect(mockSecrets.store).toHaveBeenCalledWith(
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
"roo_cline_config_api_config",
|
||||||
currentApiConfigName: 'default',
|
JSON.stringify(expectedConfig, null, 2),
|
||||||
apiConfigs: { default: {} }
|
)
|
||||||
}))
|
})
|
||||||
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed'))
|
|
||||||
|
it("should throw error if secrets storage fails", async () => {
|
||||||
await expect(configManager.SaveConfig('test', {})).rejects.toThrow(
|
mockSecrets.get.mockResolvedValue(
|
||||||
'Failed to save config: Error: Failed to write config to secrets: Error: Storage failed'
|
JSON.stringify({
|
||||||
)
|
currentApiConfigName: "default",
|
||||||
})
|
apiConfigs: { default: {} },
|
||||||
})
|
}),
|
||||||
|
)
|
||||||
describe('DeleteConfig', () => {
|
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
|
||||||
it('should delete existing config', async () => {
|
|
||||||
const existingConfig: ApiConfigData = {
|
await expect(configManager.SaveConfig("test", {})).rejects.toThrow(
|
||||||
currentApiConfigName: 'default',
|
"Failed to save config: Error: Failed to write config to secrets: Error: Storage failed",
|
||||||
apiConfigs: {
|
)
|
||||||
default: {
|
})
|
||||||
id: 'default'
|
})
|
||||||
},
|
|
||||||
test: {
|
describe("DeleteConfig", () => {
|
||||||
apiProvider: 'anthropic',
|
it("should delete existing config", async () => {
|
||||||
id: 'test-id'
|
const existingConfig: ApiConfigData = {
|
||||||
}
|
currentApiConfigName: "default",
|
||||||
}
|
apiConfigs: {
|
||||||
}
|
default: {
|
||||||
|
id: "default",
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
},
|
||||||
|
test: {
|
||||||
await configManager.DeleteConfig('test')
|
apiProvider: "anthropic",
|
||||||
|
id: "test-id",
|
||||||
// Get the stored config to check the ID
|
},
|
||||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
},
|
||||||
expect(storedConfig.currentApiConfigName).toBe('default')
|
}
|
||||||
expect(Object.keys(storedConfig.apiConfigs)).toEqual(['default'])
|
|
||||||
expect(storedConfig.apiConfigs.default.id).toBeTruthy()
|
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||||
})
|
|
||||||
|
await configManager.DeleteConfig("test")
|
||||||
it('should throw error when trying to delete non-existent config', async () => {
|
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
// Get the stored config to check the ID
|
||||||
currentApiConfigName: 'default',
|
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||||
apiConfigs: { default: {} }
|
expect(storedConfig.currentApiConfigName).toBe("default")
|
||||||
}))
|
expect(Object.keys(storedConfig.apiConfigs)).toEqual(["default"])
|
||||||
|
expect(storedConfig.apiConfigs.default.id).toBeTruthy()
|
||||||
await expect(configManager.DeleteConfig('nonexistent')).rejects.toThrow(
|
})
|
||||||
"Config 'nonexistent' not found"
|
|
||||||
)
|
it("should throw error when trying to delete non-existent config", async () => {
|
||||||
})
|
mockSecrets.get.mockResolvedValue(
|
||||||
|
JSON.stringify({
|
||||||
it('should throw error when trying to delete last remaining config', async () => {
|
currentApiConfigName: "default",
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
apiConfigs: { default: {} },
|
||||||
currentApiConfigName: 'default',
|
}),
|
||||||
apiConfigs: {
|
)
|
||||||
default: {
|
|
||||||
id: 'default'
|
await expect(configManager.DeleteConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found")
|
||||||
}
|
})
|
||||||
}
|
|
||||||
}))
|
it("should throw error when trying to delete last remaining config", async () => {
|
||||||
|
mockSecrets.get.mockResolvedValue(
|
||||||
await expect(configManager.DeleteConfig('default')).rejects.toThrow(
|
JSON.stringify({
|
||||||
'Cannot delete the last remaining configuration.'
|
currentApiConfigName: "default",
|
||||||
)
|
apiConfigs: {
|
||||||
})
|
default: {
|
||||||
})
|
id: "default",
|
||||||
|
},
|
||||||
describe('LoadConfig', () => {
|
},
|
||||||
it('should load config and update current config name', async () => {
|
}),
|
||||||
const existingConfig: ApiConfigData = {
|
)
|
||||||
currentApiConfigName: 'default',
|
|
||||||
apiConfigs: {
|
await expect(configManager.DeleteConfig("default")).rejects.toThrow(
|
||||||
test: {
|
"Cannot delete the last remaining configuration.",
|
||||||
apiProvider: 'anthropic',
|
)
|
||||||
apiKey: 'test-key',
|
})
|
||||||
id: 'test-id'
|
})
|
||||||
}
|
|
||||||
}
|
describe("LoadConfig", () => {
|
||||||
}
|
it("should load config and update current config name", async () => {
|
||||||
|
const existingConfig: ApiConfigData = {
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
currentApiConfigName: "default",
|
||||||
|
apiConfigs: {
|
||||||
const config = await configManager.LoadConfig('test')
|
test: {
|
||||||
|
apiProvider: "anthropic",
|
||||||
expect(config).toEqual({
|
apiKey: "test-key",
|
||||||
apiProvider: 'anthropic',
|
id: "test-id",
|
||||||
apiKey: 'test-key',
|
},
|
||||||
id: 'test-id'
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
// Get the stored config to check the structure
|
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
|
||||||
expect(storedConfig.currentApiConfigName).toBe('test')
|
const config = await configManager.LoadConfig("test")
|
||||||
expect(storedConfig.apiConfigs.test).toEqual({
|
|
||||||
apiProvider: 'anthropic',
|
expect(config).toEqual({
|
||||||
apiKey: 'test-key',
|
apiProvider: "anthropic",
|
||||||
id: 'test-id'
|
apiKey: "test-key",
|
||||||
})
|
id: "test-id",
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should throw error when config does not exist', async () => {
|
// Get the stored config to check the structure
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||||
currentApiConfigName: 'default',
|
expect(storedConfig.currentApiConfigName).toBe("test")
|
||||||
apiConfigs: {
|
expect(storedConfig.apiConfigs.test).toEqual({
|
||||||
default: {
|
apiProvider: "anthropic",
|
||||||
config: {},
|
apiKey: "test-key",
|
||||||
id: 'default'
|
id: "test-id",
|
||||||
}
|
})
|
||||||
}
|
})
|
||||||
}))
|
|
||||||
|
it("should throw error when config does not exist", async () => {
|
||||||
await expect(configManager.LoadConfig('nonexistent')).rejects.toThrow(
|
mockSecrets.get.mockResolvedValue(
|
||||||
"Config 'nonexistent' not found"
|
JSON.stringify({
|
||||||
)
|
currentApiConfigName: "default",
|
||||||
})
|
apiConfigs: {
|
||||||
|
default: {
|
||||||
it('should throw error if secrets storage fails', async () => {
|
config: {},
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
id: "default",
|
||||||
currentApiConfigName: 'default',
|
},
|
||||||
apiConfigs: {
|
},
|
||||||
test: {
|
}),
|
||||||
config: {
|
)
|
||||||
apiProvider: 'anthropic'
|
|
||||||
},
|
await expect(configManager.LoadConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found")
|
||||||
id: 'test-id'
|
})
|
||||||
}
|
|
||||||
}
|
it("should throw error if secrets storage fails", async () => {
|
||||||
}))
|
mockSecrets.get.mockResolvedValue(
|
||||||
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed'))
|
JSON.stringify({
|
||||||
|
currentApiConfigName: "default",
|
||||||
await expect(configManager.LoadConfig('test')).rejects.toThrow(
|
apiConfigs: {
|
||||||
'Failed to load config: Error: Failed to write config to secrets: Error: Storage failed'
|
test: {
|
||||||
)
|
config: {
|
||||||
})
|
apiProvider: "anthropic",
|
||||||
})
|
},
|
||||||
|
id: "test-id",
|
||||||
describe('SetCurrentConfig', () => {
|
},
|
||||||
it('should set current config', async () => {
|
},
|
||||||
const existingConfig: ApiConfigData = {
|
}),
|
||||||
currentApiConfigName: 'default',
|
)
|
||||||
apiConfigs: {
|
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
|
||||||
default: {
|
|
||||||
id: 'default'
|
await expect(configManager.LoadConfig("test")).rejects.toThrow(
|
||||||
},
|
"Failed to load config: Error: Failed to write config to secrets: Error: Storage failed",
|
||||||
test: {
|
)
|
||||||
apiProvider: 'anthropic',
|
})
|
||||||
id: 'test-id'
|
})
|
||||||
}
|
|
||||||
}
|
describe("SetCurrentConfig", () => {
|
||||||
}
|
it("should set current config", async () => {
|
||||||
|
const existingConfig: ApiConfigData = {
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
currentApiConfigName: "default",
|
||||||
|
apiConfigs: {
|
||||||
await configManager.SetCurrentConfig('test')
|
default: {
|
||||||
|
id: "default",
|
||||||
// Get the stored config to check the structure
|
},
|
||||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
test: {
|
||||||
expect(storedConfig.currentApiConfigName).toBe('test')
|
apiProvider: "anthropic",
|
||||||
expect(storedConfig.apiConfigs.default.id).toBe('default')
|
id: "test-id",
|
||||||
expect(storedConfig.apiConfigs.test).toEqual({
|
},
|
||||||
apiProvider: 'anthropic',
|
},
|
||||||
id: 'test-id'
|
}
|
||||||
})
|
|
||||||
})
|
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||||
|
|
||||||
it('should throw error when config does not exist', async () => {
|
await configManager.SetCurrentConfig("test")
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
|
||||||
currentApiConfigName: 'default',
|
// Get the stored config to check the structure
|
||||||
apiConfigs: { default: {} }
|
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||||
}))
|
expect(storedConfig.currentApiConfigName).toBe("test")
|
||||||
|
expect(storedConfig.apiConfigs.default.id).toBe("default")
|
||||||
await expect(configManager.SetCurrentConfig('nonexistent')).rejects.toThrow(
|
expect(storedConfig.apiConfigs.test).toEqual({
|
||||||
"Config 'nonexistent' not found"
|
apiProvider: "anthropic",
|
||||||
)
|
id: "test-id",
|
||||||
})
|
})
|
||||||
|
})
|
||||||
it('should throw error if secrets storage fails', async () => {
|
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
it("should throw error when config does not exist", async () => {
|
||||||
currentApiConfigName: 'default',
|
mockSecrets.get.mockResolvedValue(
|
||||||
apiConfigs: {
|
JSON.stringify({
|
||||||
test: { apiProvider: 'anthropic' }
|
currentApiConfigName: "default",
|
||||||
}
|
apiConfigs: { default: {} },
|
||||||
}))
|
}),
|
||||||
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed'))
|
)
|
||||||
|
|
||||||
await expect(configManager.SetCurrentConfig('test')).rejects.toThrow(
|
await expect(configManager.SetCurrentConfig("nonexistent")).rejects.toThrow(
|
||||||
'Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed'
|
"Config 'nonexistent' not found",
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
})
|
|
||||||
|
it("should throw error if secrets storage fails", async () => {
|
||||||
describe('HasConfig', () => {
|
mockSecrets.get.mockResolvedValue(
|
||||||
it('should return true for existing config', async () => {
|
JSON.stringify({
|
||||||
const existingConfig: ApiConfigData = {
|
currentApiConfigName: "default",
|
||||||
currentApiConfigName: 'default',
|
apiConfigs: {
|
||||||
apiConfigs: {
|
test: { apiProvider: "anthropic" },
|
||||||
default: {
|
},
|
||||||
id: 'default'
|
}),
|
||||||
},
|
)
|
||||||
test: {
|
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
|
||||||
apiProvider: 'anthropic',
|
|
||||||
id: 'test-id'
|
await expect(configManager.SetCurrentConfig("test")).rejects.toThrow(
|
||||||
}
|
"Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed",
|
||||||
}
|
)
|
||||||
}
|
})
|
||||||
|
})
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
|
||||||
|
describe("HasConfig", () => {
|
||||||
const hasConfig = await configManager.HasConfig('test')
|
it("should return true for existing config", async () => {
|
||||||
expect(hasConfig).toBe(true)
|
const existingConfig: ApiConfigData = {
|
||||||
})
|
currentApiConfigName: "default",
|
||||||
|
apiConfigs: {
|
||||||
it('should return false for non-existent config', async () => {
|
default: {
|
||||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
id: "default",
|
||||||
currentApiConfigName: 'default',
|
},
|
||||||
apiConfigs: { default: {} }
|
test: {
|
||||||
}))
|
apiProvider: "anthropic",
|
||||||
|
id: "test-id",
|
||||||
const hasConfig = await configManager.HasConfig('nonexistent')
|
},
|
||||||
expect(hasConfig).toBe(false)
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
it('should throw error if secrets storage fails', async () => {
|
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||||
mockSecrets.get.mockRejectedValue(new Error('Storage failed'))
|
|
||||||
|
const hasConfig = await configManager.HasConfig("test")
|
||||||
await expect(configManager.HasConfig('test')).rejects.toThrow(
|
expect(hasConfig).toBe(true)
|
||||||
'Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed'
|
})
|
||||||
)
|
|
||||||
})
|
it("should return false for non-existent config", async () => {
|
||||||
})
|
mockSecrets.get.mockResolvedValue(
|
||||||
})
|
JSON.stringify({
|
||||||
|
currentApiConfigName: "default",
|
||||||
|
apiConfigs: { default: {} },
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
const hasConfig = await configManager.HasConfig("nonexistent")
|
||||||
|
expect(hasConfig).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should throw error if secrets storage fails", async () => {
|
||||||
|
mockSecrets.get.mockRejectedValue(new Error("Storage failed"))
|
||||||
|
|
||||||
|
await expect(configManager.HasConfig("test")).rejects.toThrow(
|
||||||
|
"Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed",
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
import type { DiffStrategy } from './types'
|
import type { DiffStrategy } from "./types"
|
||||||
import { UnifiedDiffStrategy } from './strategies/unified'
|
import { UnifiedDiffStrategy } from "./strategies/unified"
|
||||||
import { SearchReplaceDiffStrategy } from './strategies/search-replace'
|
import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
|
||||||
import { NewUnifiedDiffStrategy } from './strategies/new-unified'
|
import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
|
||||||
/**
|
/**
|
||||||
* Get the appropriate diff strategy for the given model
|
* Get the appropriate diff strategy for the given model
|
||||||
* @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus')
|
* @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus')
|
||||||
* @returns The appropriate diff strategy for the model
|
* @returns The appropriate diff strategy for the model
|
||||||
*/
|
*/
|
||||||
export function getDiffStrategy(model: string, fuzzyMatchThreshold?: number, experimentalDiffStrategy: boolean = false): DiffStrategy {
|
export function getDiffStrategy(
|
||||||
if (experimentalDiffStrategy) {
|
model: string,
|
||||||
return new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
|
fuzzyMatchThreshold?: number,
|
||||||
}
|
experimentalDiffStrategy: boolean = false,
|
||||||
return new SearchReplaceDiffStrategy(fuzzyMatchThreshold)
|
): DiffStrategy {
|
||||||
|
if (experimentalDiffStrategy) {
|
||||||
|
return new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
|
||||||
|
}
|
||||||
|
return new SearchReplaceDiffStrategy(fuzzyMatchThreshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
export type { DiffStrategy }
|
export type { DiffStrategy }
|
||||||
|
|||||||
@@ -1,74 +1,73 @@
|
|||||||
import { NewUnifiedDiffStrategy } from '../new-unified';
|
import { NewUnifiedDiffStrategy } from "../new-unified"
|
||||||
|
|
||||||
describe('main', () => {
|
describe("main", () => {
|
||||||
|
let strategy: NewUnifiedDiffStrategy
|
||||||
|
|
||||||
let strategy: NewUnifiedDiffStrategy
|
beforeEach(() => {
|
||||||
|
strategy = new NewUnifiedDiffStrategy(0.97)
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
describe("constructor", () => {
|
||||||
strategy = new NewUnifiedDiffStrategy(0.97)
|
it("should use default confidence threshold when not provided", () => {
|
||||||
})
|
const defaultStrategy = new NewUnifiedDiffStrategy()
|
||||||
|
expect(defaultStrategy["confidenceThreshold"]).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
describe('constructor', () => {
|
it("should use provided confidence threshold", () => {
|
||||||
it('should use default confidence threshold when not provided', () => {
|
const customStrategy = new NewUnifiedDiffStrategy(0.85)
|
||||||
const defaultStrategy = new NewUnifiedDiffStrategy()
|
expect(customStrategy["confidenceThreshold"]).toBe(0.85)
|
||||||
expect(defaultStrategy['confidenceThreshold']).toBe(1)
|
})
|
||||||
})
|
|
||||||
|
|
||||||
it('should use provided confidence threshold', () => {
|
it("should enforce minimum confidence threshold", () => {
|
||||||
const customStrategy = new NewUnifiedDiffStrategy(0.85)
|
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
|
||||||
expect(customStrategy['confidenceThreshold']).toBe(0.85)
|
expect(lowStrategy["confidenceThreshold"]).toBe(0.8)
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('should enforce minimum confidence threshold', () => {
|
describe("getToolDescription", () => {
|
||||||
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
|
it("should return tool description with correct cwd", () => {
|
||||||
expect(lowStrategy['confidenceThreshold']).toBe(0.8)
|
const cwd = "/test/path"
|
||||||
})
|
const description = strategy.getToolDescription({ cwd })
|
||||||
})
|
|
||||||
|
|
||||||
describe('getToolDescription', () => {
|
expect(description).toContain("apply_diff")
|
||||||
it('should return tool description with correct cwd', () => {
|
expect(description).toContain(cwd)
|
||||||
const cwd = '/test/path'
|
expect(description).toContain("Parameters:")
|
||||||
const description = strategy.getToolDescription({ cwd })
|
expect(description).toContain("Format Requirements:")
|
||||||
|
})
|
||||||
expect(description).toContain('apply_diff')
|
})
|
||||||
expect(description).toContain(cwd)
|
|
||||||
expect(description).toContain('Parameters:')
|
|
||||||
expect(description).toContain('Format Requirements:')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should apply simple diff correctly', async () => {
|
it("should apply simple diff correctly", async () => {
|
||||||
const original = `line1
|
const original = `line1
|
||||||
line2
|
line2
|
||||||
line3`;
|
line3`
|
||||||
|
|
||||||
const diff = `--- a/file.txt
|
const diff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
+new line
|
+new line
|
||||||
line2
|
line2
|
||||||
-line3
|
-line3
|
||||||
+modified line3`;
|
+modified line3`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if(result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(`line1
|
expect(result.content).toBe(`line1
|
||||||
new line
|
new line
|
||||||
line2
|
line2
|
||||||
modified line3`);
|
modified line3`)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle multiple hunks', async () => {
|
it("should handle multiple hunks", async () => {
|
||||||
const original = `line1
|
const original = `line1
|
||||||
line2
|
line2
|
||||||
line3
|
line3
|
||||||
line4
|
line4
|
||||||
line5`;
|
line5`
|
||||||
|
|
||||||
const diff = `--- a/file.txt
|
const diff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
@@ -80,23 +79,23 @@ line5`;
|
|||||||
line4
|
line4
|
||||||
-line5
|
-line5
|
||||||
+modified line5
|
+modified line5
|
||||||
+new line at end`;
|
+new line at end`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(`line1
|
expect(result.content).toBe(`line1
|
||||||
new line
|
new line
|
||||||
line2
|
line2
|
||||||
modified line3
|
modified line3
|
||||||
line4
|
line4
|
||||||
modified line5
|
modified line5
|
||||||
new line at end`);
|
new line at end`)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle complex large', async () => {
|
it("should handle complex large", async () => {
|
||||||
const original = `line1
|
const original = `line1
|
||||||
line2
|
line2
|
||||||
line3
|
line3
|
||||||
line4
|
line4
|
||||||
@@ -105,9 +104,9 @@ line6
|
|||||||
line7
|
line7
|
||||||
line8
|
line8
|
||||||
line9
|
line9
|
||||||
line10`;
|
line10`
|
||||||
|
|
||||||
const diff = `--- a/file.txt
|
const diff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
@@ -130,12 +129,12 @@ line10`;
|
|||||||
line9
|
line9
|
||||||
-line10
|
-line10
|
||||||
+final line
|
+final line
|
||||||
+very last line`;
|
+very last line`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(`line1
|
expect(result.content).toBe(`line1
|
||||||
header line
|
header line
|
||||||
another header
|
another header
|
||||||
line2
|
line2
|
||||||
@@ -150,12 +149,12 @@ changed line8
|
|||||||
bonus line
|
bonus line
|
||||||
line9
|
line9
|
||||||
final line
|
final line
|
||||||
very last line`);
|
very last line`)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle indentation changes', async () => {
|
it("should handle indentation changes", async () => {
|
||||||
const original = `first line
|
const original = `first line
|
||||||
indented line
|
indented line
|
||||||
double indented line
|
double indented line
|
||||||
back to single indent
|
back to single indent
|
||||||
@@ -164,9 +163,9 @@ no indent
|
|||||||
double indent again
|
double indent again
|
||||||
triple indent
|
triple indent
|
||||||
back to single
|
back to single
|
||||||
last line`;
|
last line`
|
||||||
|
|
||||||
const diff = `--- original
|
const diff = `--- original
|
||||||
+++ modified
|
+++ modified
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
first line
|
first line
|
||||||
@@ -181,9 +180,9 @@ last line`;
|
|||||||
- triple indent
|
- triple indent
|
||||||
+ hi there mate
|
+ hi there mate
|
||||||
back to single
|
back to single
|
||||||
last line`;
|
last line`
|
||||||
|
|
||||||
const expected = `first line
|
const expected = `first line
|
||||||
indented line
|
indented line
|
||||||
tab indented line
|
tab indented line
|
||||||
new indented line
|
new indented line
|
||||||
@@ -194,23 +193,22 @@ no indent
|
|||||||
double indent again
|
double indent again
|
||||||
hi there mate
|
hi there mate
|
||||||
back to single
|
back to single
|
||||||
last line`;
|
last line`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected);
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle high level edits', async () => {
|
it("should handle high level edits", async () => {
|
||||||
|
const original = `def factorial(n):
|
||||||
const original = `def factorial(n):
|
|
||||||
if n == 0:
|
if n == 0:
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return n * factorial(n-1)`
|
return n * factorial(n-1)`
|
||||||
const diff = `@@ ... @@
|
const diff = `@@ ... @@
|
||||||
-def factorial(n):
|
-def factorial(n):
|
||||||
- if n == 0:
|
- if n == 0:
|
||||||
- return 1
|
- return 1
|
||||||
@@ -222,21 +220,21 @@ last line`;
|
|||||||
+ else:
|
+ else:
|
||||||
+ return number * factorial(number-1)`
|
+ return number * factorial(number-1)`
|
||||||
|
|
||||||
const expected = `def factorial(number):
|
const expected = `def factorial(number):
|
||||||
if number == 0:
|
if number == 0:
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return number * factorial(number-1)`
|
return number * factorial(number-1)`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected);
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('it should handle very complex edits', async () => {
|
it("it should handle very complex edits", async () => {
|
||||||
const original = `//Initialize the array that will hold the primes
|
const original = `//Initialize the array that will hold the primes
|
||||||
var primeArray = [];
|
var primeArray = [];
|
||||||
/*Write a function that checks for primeness and
|
/*Write a function that checks for primeness and
|
||||||
pushes those values to t*he array*/
|
pushes those values to t*he array*/
|
||||||
@@ -269,7 +267,7 @@ for (var i = 2; primeArray.length < numPrimes; i++) {
|
|||||||
console.log(primeArray);
|
console.log(primeArray);
|
||||||
`
|
`
|
||||||
|
|
||||||
const diff = `--- test_diff.js
|
const diff = `--- test_diff.js
|
||||||
+++ test_diff.js
|
+++ test_diff.js
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
-//Initialize the array that will hold the primes
|
-//Initialize the array that will hold the primes
|
||||||
@@ -297,7 +295,7 @@ console.log(primeArray);
|
|||||||
}
|
}
|
||||||
console.log(primeArray);`
|
console.log(primeArray);`
|
||||||
|
|
||||||
const expected = `var primeArray = [];
|
const expected = `var primeArray = [];
|
||||||
function PrimeCheck(candidate){
|
function PrimeCheck(candidate){
|
||||||
isPrime = true;
|
isPrime = true;
|
||||||
for(var i = 2; i < candidate && isPrime; i++){
|
for(var i = 2; i < candidate && isPrime; i++){
|
||||||
@@ -320,58 +318,57 @@ for (var i = 2; primeArray.length < numPrimes; i++) {
|
|||||||
}
|
}
|
||||||
console.log(primeArray);
|
console.log(primeArray);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected);
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('error handling and edge cases', () => {
|
describe("error handling and edge cases", () => {
|
||||||
it('should reject completely invalid diff format', async () => {
|
it("should reject completely invalid diff format", async () => {
|
||||||
const original = 'line1\nline2\nline3';
|
const original = "line1\nline2\nline3"
|
||||||
const invalidDiff = 'this is not a diff at all';
|
const invalidDiff = "this is not a diff at all"
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, invalidDiff);
|
|
||||||
expect(result.success).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should reject diff with invalid hunk format', async () => {
|
const result = await strategy.applyDiff(original, invalidDiff)
|
||||||
const original = 'line1\nline2\nline3';
|
expect(result.success).toBe(false)
|
||||||
const invalidHunkDiff = `--- a/file.txt
|
})
|
||||||
|
|
||||||
|
it("should reject diff with invalid hunk format", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
const invalidHunkDiff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
invalid hunk header
|
invalid hunk header
|
||||||
line1
|
line1
|
||||||
-line2
|
-line2
|
||||||
+new line`;
|
+new line`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, invalidHunkDiff);
|
|
||||||
expect(result.success).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should fail when diff tries to modify non-existent content', async () => {
|
const result = await strategy.applyDiff(original, invalidHunkDiff)
|
||||||
const original = 'line1\nline2\nline3';
|
expect(result.success).toBe(false)
|
||||||
const nonMatchingDiff = `--- a/file.txt
|
})
|
||||||
|
|
||||||
|
it("should fail when diff tries to modify non-existent content", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
const nonMatchingDiff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
-nonexistent line
|
-nonexistent line
|
||||||
+new line
|
+new line
|
||||||
line3`;
|
line3`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, nonMatchingDiff);
|
|
||||||
expect(result.success).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle overlapping hunks', async () => {
|
const result = await strategy.applyDiff(original, nonMatchingDiff)
|
||||||
const original = `line1
|
expect(result.success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle overlapping hunks", async () => {
|
||||||
|
const original = `line1
|
||||||
line2
|
line2
|
||||||
line3
|
line3
|
||||||
line4
|
line4
|
||||||
line5`;
|
line5`
|
||||||
const overlappingDiff = `--- a/file.txt
|
const overlappingDiff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
@@ -384,19 +381,19 @@ line5`;
|
|||||||
-line3
|
-line3
|
||||||
-line4
|
-line4
|
||||||
+modified3and4
|
+modified3and4
|
||||||
line5`;
|
line5`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, overlappingDiff);
|
|
||||||
expect(result.success).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle empty lines modifications', async () => {
|
const result = await strategy.applyDiff(original, overlappingDiff)
|
||||||
const original = `line1
|
expect(result.success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle empty lines modifications", async () => {
|
||||||
|
const original = `line1
|
||||||
|
|
||||||
line3
|
line3
|
||||||
|
|
||||||
line5`;
|
line5`
|
||||||
const emptyLinesDiff = `--- a/file.txt
|
const emptyLinesDiff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
@@ -404,73 +401,73 @@ line5`;
|
|||||||
-line3
|
-line3
|
||||||
+line3modified
|
+line3modified
|
||||||
|
|
||||||
line5`;
|
line5`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, emptyLinesDiff);
|
const result = await strategy.applyDiff(original, emptyLinesDiff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(`line1
|
expect(result.content).toBe(`line1
|
||||||
|
|
||||||
line3modified
|
line3modified
|
||||||
|
|
||||||
line5`);
|
line5`)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle mixed line endings in diff', async () => {
|
it("should handle mixed line endings in diff", async () => {
|
||||||
const original = 'line1\r\nline2\nline3\r\n';
|
const original = "line1\r\nline2\nline3\r\n"
|
||||||
const mixedEndingsDiff = `--- a/file.txt
|
const mixedEndingsDiff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1\r
|
line1\r
|
||||||
-line2
|
-line2
|
||||||
+modified2\r
|
+modified2\r
|
||||||
line3`;
|
line3`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, mixedEndingsDiff);
|
|
||||||
expect(result.success).toBe(true);
|
|
||||||
if (result.success) {
|
|
||||||
expect(result.content).toBe('line1\r\nmodified2\r\nline3\r\n');
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle partial line modifications', async () => {
|
const result = await strategy.applyDiff(original, mixedEndingsDiff)
|
||||||
const original = 'const value = oldValue + 123;';
|
expect(result.success).toBe(true)
|
||||||
const partialDiff = `--- a/file.txt
|
if (result.success) {
|
||||||
|
expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle partial line modifications", async () => {
|
||||||
|
const original = "const value = oldValue + 123;"
|
||||||
|
const partialDiff = `--- a/file.txt
|
||||||
+++ b/file.txt
|
+++ b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
-const value = oldValue + 123;
|
-const value = oldValue + 123;
|
||||||
+const value = newValue + 123;`;
|
+const value = newValue + 123;`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, partialDiff);
|
|
||||||
expect(result.success).toBe(true);
|
|
||||||
if (result.success) {
|
|
||||||
expect(result.content).toBe('const value = newValue + 123;');
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle slightly malformed but recoverable diff', async () => {
|
const result = await strategy.applyDiff(original, partialDiff)
|
||||||
const original = 'line1\nline2\nline3';
|
expect(result.success).toBe(true)
|
||||||
// Missing space after --- and +++
|
if (result.success) {
|
||||||
const slightlyBadDiff = `---a/file.txt
|
expect(result.content).toBe("const value = newValue + 123;")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle slightly malformed but recoverable diff", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
// Missing space after --- and +++
|
||||||
|
const slightlyBadDiff = `---a/file.txt
|
||||||
+++b/file.txt
|
+++b/file.txt
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
line1
|
line1
|
||||||
-line2
|
-line2
|
||||||
+new line
|
+new line
|
||||||
line3`;
|
line3`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, slightlyBadDiff);
|
|
||||||
expect(result.success).toBe(true);
|
|
||||||
if (result.success) {
|
|
||||||
expect(result.content).toBe('line1\nnew line\nline3');
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('similar code sections', () => {
|
const result = await strategy.applyDiff(original, slightlyBadDiff)
|
||||||
it('should correctly modify the right section when similar code exists', async () => {
|
expect(result.success).toBe(true)
|
||||||
const original = `function add(a, b) {
|
if (result.success) {
|
||||||
|
expect(result.content).toBe("line1\nnew line\nline3")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("similar code sections", () => {
|
||||||
|
it("should correctly modify the right section when similar code exists", async () => {
|
||||||
|
const original = `function add(a, b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -480,20 +477,20 @@ function subtract(a, b) {
|
|||||||
|
|
||||||
function multiply(a, b) {
|
function multiply(a, b) {
|
||||||
return a + b; // Bug here
|
return a + b; // Bug here
|
||||||
}`;
|
}`
|
||||||
|
|
||||||
const diff = `--- a/math.js
|
const diff = `--- a/math.js
|
||||||
+++ b/math.js
|
+++ b/math.js
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
function multiply(a, b) {
|
function multiply(a, b) {
|
||||||
- return a + b; // Bug here
|
- return a + b; // Bug here
|
||||||
+ return a * b;
|
+ return a * b;
|
||||||
}`;
|
}`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(`function add(a, b) {
|
expect(result.content).toBe(`function add(a, b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -503,12 +500,12 @@ function subtract(a, b) {
|
|||||||
|
|
||||||
function multiply(a, b) {
|
function multiply(a, b) {
|
||||||
return a * b;
|
return a * b;
|
||||||
}`);
|
}`)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should handle multiple similar sections with correct context', async () => {
|
it("should handle multiple similar sections with correct context", async () => {
|
||||||
const original = `if (condition) {
|
const original = `if (condition) {
|
||||||
doSomething();
|
doSomething();
|
||||||
doSomething();
|
doSomething();
|
||||||
doSomething();
|
doSomething();
|
||||||
@@ -518,9 +515,9 @@ if (otherCondition) {
|
|||||||
doSomething();
|
doSomething();
|
||||||
doSomething();
|
doSomething();
|
||||||
doSomething();
|
doSomething();
|
||||||
}`;
|
}`
|
||||||
|
|
||||||
const diff = `--- a/file.js
|
const diff = `--- a/file.js
|
||||||
+++ b/file.js
|
+++ b/file.js
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
if (otherCondition) {
|
if (otherCondition) {
|
||||||
@@ -528,12 +525,12 @@ if (otherCondition) {
|
|||||||
- doSomething();
|
- doSomething();
|
||||||
+ doSomethingElse();
|
+ doSomethingElse();
|
||||||
doSomething();
|
doSomething();
|
||||||
}`;
|
}`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(`if (condition) {
|
expect(result.content).toBe(`if (condition) {
|
||||||
doSomething();
|
doSomething();
|
||||||
doSomething();
|
doSomething();
|
||||||
doSomething();
|
doSomething();
|
||||||
@@ -543,14 +540,14 @@ if (otherCondition) {
|
|||||||
doSomething();
|
doSomething();
|
||||||
doSomethingElse();
|
doSomethingElse();
|
||||||
doSomething();
|
doSomething();
|
||||||
}`);
|
}`)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('hunk splitting', () => {
|
describe("hunk splitting", () => {
|
||||||
it('should handle large diffs with multiple non-contiguous changes', async () => {
|
it("should handle large diffs with multiple non-contiguous changes", async () => {
|
||||||
const original = `import { readFile } from 'fs';
|
const original = `import { readFile } from 'fs';
|
||||||
import { join } from 'path';
|
import { join } from 'path';
|
||||||
import { Logger } from './logger';
|
import { Logger } from './logger';
|
||||||
|
|
||||||
@@ -595,9 +592,9 @@ export {
|
|||||||
validateInput,
|
validateInput,
|
||||||
writeOutput,
|
writeOutput,
|
||||||
parseConfig
|
parseConfig
|
||||||
};`;
|
};`
|
||||||
|
|
||||||
const diff = `--- a/file.ts
|
const diff = `--- a/file.ts
|
||||||
+++ b/file.ts
|
+++ b/file.ts
|
||||||
@@ ... @@
|
@@ ... @@
|
||||||
-import { readFile } from 'fs';
|
-import { readFile } from 'fs';
|
||||||
@@ -672,9 +669,9 @@ export {
|
|||||||
- parseConfig
|
- parseConfig
|
||||||
+ parseConfig,
|
+ parseConfig,
|
||||||
+ type Config
|
+ type Config
|
||||||
};`;
|
};`
|
||||||
|
|
||||||
const expected = `import { readFile, writeFile } from 'fs';
|
const expected = `import { readFile, writeFile } from 'fs';
|
||||||
import { join } from 'path';
|
import { join } from 'path';
|
||||||
import { Logger } from './utils/logger';
|
import { Logger } from './utils/logger';
|
||||||
import { Config } from './types';
|
import { Config } from './types';
|
||||||
@@ -727,13 +724,13 @@ export {
|
|||||||
writeOutput,
|
writeOutput,
|
||||||
parseConfig,
|
parseConfig,
|
||||||
type Config
|
type Config
|
||||||
};`;
|
};`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(original, diff);
|
const result = await strategy.applyDiff(original, diff)
|
||||||
expect(result.success).toBe(true);
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected);
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,27 +1,27 @@
|
|||||||
import { UnifiedDiffStrategy } from '../unified'
|
import { UnifiedDiffStrategy } from "../unified"
|
||||||
|
|
||||||
describe('UnifiedDiffStrategy', () => {
|
describe("UnifiedDiffStrategy", () => {
|
||||||
let strategy: UnifiedDiffStrategy
|
let strategy: UnifiedDiffStrategy
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
strategy = new UnifiedDiffStrategy()
|
strategy = new UnifiedDiffStrategy()
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('getToolDescription', () => {
|
describe("getToolDescription", () => {
|
||||||
it('should return tool description with correct cwd', () => {
|
it("should return tool description with correct cwd", () => {
|
||||||
const cwd = '/test/path'
|
const cwd = "/test/path"
|
||||||
const description = strategy.getToolDescription({ cwd })
|
const description = strategy.getToolDescription({ cwd })
|
||||||
|
|
||||||
expect(description).toContain('apply_diff')
|
|
||||||
expect(description).toContain(cwd)
|
|
||||||
expect(description).toContain('Parameters:')
|
|
||||||
expect(description).toContain('Format Requirements:')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('applyDiff', () => {
|
expect(description).toContain("apply_diff")
|
||||||
it('should successfully apply a function modification diff', async () => {
|
expect(description).toContain(cwd)
|
||||||
const originalContent = `import { Logger } from '../logger';
|
expect(description).toContain("Parameters:")
|
||||||
|
expect(description).toContain("Format Requirements:")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("applyDiff", () => {
|
||||||
|
it("should successfully apply a function modification diff", async () => {
|
||||||
|
const originalContent = `import { Logger } from '../logger';
|
||||||
|
|
||||||
function calculateTotal(items: number[]): number {
|
function calculateTotal(items: number[]): number {
|
||||||
return items.reduce((sum, item) => {
|
return items.reduce((sum, item) => {
|
||||||
@@ -31,7 +31,7 @@ function calculateTotal(items: number[]): number {
|
|||||||
|
|
||||||
export { calculateTotal };`
|
export { calculateTotal };`
|
||||||
|
|
||||||
const diffContent = `--- src/utils/helper.ts
|
const diffContent = `--- src/utils/helper.ts
|
||||||
+++ src/utils/helper.ts
|
+++ src/utils/helper.ts
|
||||||
@@ -1,9 +1,10 @@
|
@@ -1,9 +1,10 @@
|
||||||
import { Logger } from '../logger';
|
import { Logger } from '../logger';
|
||||||
@@ -47,7 +47,7 @@ export { calculateTotal };`
|
|||||||
|
|
||||||
export { calculateTotal };`
|
export { calculateTotal };`
|
||||||
|
|
||||||
const expected = `import { Logger } from '../logger';
|
const expected = `import { Logger } from '../logger';
|
||||||
|
|
||||||
function calculateTotal(items: number[]): number {
|
function calculateTotal(items: number[]): number {
|
||||||
const total = items.reduce((sum, item) => {
|
const total = items.reduce((sum, item) => {
|
||||||
@@ -58,21 +58,21 @@ function calculateTotal(items: number[]): number {
|
|||||||
|
|
||||||
export { calculateTotal };`
|
export { calculateTotal };`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
expect(result.success).toBe(true)
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected)
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should successfully apply a diff adding a new method', async () => {
|
it("should successfully apply a diff adding a new method", async () => {
|
||||||
const originalContent = `class Calculator {
|
const originalContent = `class Calculator {
|
||||||
add(a: number, b: number): number {
|
add(a: number, b: number): number {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const diffContent = `--- src/Calculator.ts
|
const diffContent = `--- src/Calculator.ts
|
||||||
+++ src/Calculator.ts
|
+++ src/Calculator.ts
|
||||||
@@ -1,5 +1,9 @@
|
@@ -1,5 +1,9 @@
|
||||||
class Calculator {
|
class Calculator {
|
||||||
@@ -85,7 +85,7 @@ export { calculateTotal };`
|
|||||||
+ }
|
+ }
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const expected = `class Calculator {
|
const expected = `class Calculator {
|
||||||
add(a: number, b: number): number {
|
add(a: number, b: number): number {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
@@ -95,15 +95,15 @@ export { calculateTotal };`
|
|||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
expect(result.success).toBe(true)
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected)
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should successfully apply a diff modifying imports', async () => {
|
it("should successfully apply a diff modifying imports", async () => {
|
||||||
const originalContent = `import { useState } from 'react';
|
const originalContent = `import { useState } from 'react';
|
||||||
import { Button } from './components';
|
import { Button } from './components';
|
||||||
|
|
||||||
function App() {
|
function App() {
|
||||||
@@ -111,7 +111,7 @@ function App() {
|
|||||||
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const diffContent = `--- src/App.tsx
|
const diffContent = `--- src/App.tsx
|
||||||
+++ src/App.tsx
|
+++ src/App.tsx
|
||||||
@@ -1,7 +1,8 @@
|
@@ -1,7 +1,8 @@
|
||||||
-import { useState } from 'react';
|
-import { useState } from 'react';
|
||||||
@@ -124,7 +124,7 @@ function App() {
|
|||||||
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const expected = `import { useState, useEffect } from 'react';
|
const expected = `import { useState, useEffect } from 'react';
|
||||||
import { Button } from './components';
|
import { Button } from './components';
|
||||||
|
|
||||||
function App() {
|
function App() {
|
||||||
@@ -132,16 +132,16 @@ function App() {
|
|||||||
useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
|
useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
|
||||||
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
|
||||||
expect(result.success).toBe(true)
|
|
||||||
if (result.success) {
|
|
||||||
expect(result.content).toBe(expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should successfully apply a diff with multiple hunks', async () => {
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
const originalContent = `import { readFile, writeFile } from 'fs';
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should successfully apply a diff with multiple hunks", async () => {
|
||||||
|
const originalContent = `import { readFile, writeFile } from 'fs';
|
||||||
|
|
||||||
function processFile(path: string) {
|
function processFile(path: string) {
|
||||||
readFile(path, 'utf8', (err, data) => {
|
readFile(path, 'utf8', (err, data) => {
|
||||||
@@ -155,7 +155,7 @@ function processFile(path: string) {
|
|||||||
|
|
||||||
export { processFile };`
|
export { processFile };`
|
||||||
|
|
||||||
const diffContent = `--- src/file-processor.ts
|
const diffContent = `--- src/file-processor.ts
|
||||||
+++ src/file-processor.ts
|
+++ src/file-processor.ts
|
||||||
@@ -1,12 +1,14 @@
|
@@ -1,12 +1,14 @@
|
||||||
-import { readFile, writeFile } from 'fs';
|
-import { readFile, writeFile } from 'fs';
|
||||||
@@ -182,7 +182,7 @@ export { processFile };`
|
|||||||
|
|
||||||
export { processFile };`
|
export { processFile };`
|
||||||
|
|
||||||
const expected = `import { promises as fs } from 'fs';
|
const expected = `import { promises as fs } from 'fs';
|
||||||
import { join } from 'path';
|
import { join } from 'path';
|
||||||
|
|
||||||
async function processFile(path: string) {
|
async function processFile(path: string) {
|
||||||
@@ -198,32 +198,31 @@ async function processFile(path: string) {
|
|||||||
|
|
||||||
export { processFile };`
|
export { processFile };`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
expect(result.success).toBe(true)
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected)
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle empty original content', async () => {
|
it("should handle empty original content", async () => {
|
||||||
const originalContent = ''
|
const originalContent = ""
|
||||||
const diffContent = `--- empty.ts
|
const diffContent = `--- empty.ts
|
||||||
+++ empty.ts
|
+++ empty.ts
|
||||||
@@ -0,0 +1,3 @@
|
@@ -0,0 +1,3 @@
|
||||||
+export function greet(name: string): string {
|
+export function greet(name: string): string {
|
||||||
+ return \`Hello, \${name}!\`;
|
+ return \`Hello, \${name}!\`;
|
||||||
+}`
|
+}`
|
||||||
|
|
||||||
const expected = `export function greet(name: string): string {
|
const expected = `export function greet(name: string): string {
|
||||||
return \`Hello, \${name}!\`;
|
return \`Hello, \${name}!\`;
|
||||||
}\n`
|
}\n`
|
||||||
|
|
||||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
expect(result.success).toBe(true)
|
expect(result.success).toBe(true)
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
expect(result.content).toBe(expected)
|
expect(result.content).toBe(expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -265,8 +265,8 @@ describe("applyGitFallback", () => {
|
|||||||
{ type: "context", content: "line1", indent: "" },
|
{ type: "context", content: "line1", indent: "" },
|
||||||
{ type: "remove", content: "line2", indent: "" },
|
{ type: "remove", content: "line2", indent: "" },
|
||||||
{ type: "add", content: "new line2", indent: "" },
|
{ type: "add", content: "new line2", indent: "" },
|
||||||
{ type: "context", content: "line3", indent: "" }
|
{ type: "context", content: "line3", indent: "" },
|
||||||
]
|
],
|
||||||
} as Hunk
|
} as Hunk
|
||||||
|
|
||||||
const content = ["line1", "line2", "line3"]
|
const content = ["line1", "line2", "line3"]
|
||||||
@@ -281,8 +281,8 @@ describe("applyGitFallback", () => {
|
|||||||
const hunk = {
|
const hunk = {
|
||||||
changes: [
|
changes: [
|
||||||
{ type: "context", content: "nonexistent", indent: "" },
|
{ type: "context", content: "nonexistent", indent: "" },
|
||||||
{ type: "add", content: "new line", indent: "" }
|
{ type: "add", content: "new line", indent: "" },
|
||||||
]
|
],
|
||||||
} as Hunk
|
} as Hunk
|
||||||
|
|
||||||
const content = ["line1", "line2", "line3"]
|
const content = ["line1", "line2", "line3"]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMa
|
|||||||
type SearchStrategy = (
|
type SearchStrategy = (
|
||||||
searchStr: string,
|
searchStr: string,
|
||||||
content: string[],
|
content: string[],
|
||||||
startIndex?: number
|
startIndex?: number,
|
||||||
) => {
|
) => {
|
||||||
index: number
|
index: number
|
||||||
confidence: number
|
confidence: number
|
||||||
@@ -11,141 +11,141 @@ type SearchStrategy = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
const testCases = [
|
const testCases = [
|
||||||
{
|
{
|
||||||
name: "should return no match if the search string is not found",
|
name: "should return no match if the search string is not found",
|
||||||
searchStr: "not found",
|
searchStr: "not found",
|
||||||
content: ["line1", "line2", "line3"],
|
content: ["line1", "line2", "line3"],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match if the search string is found",
|
name: "should return a match if the search string is found",
|
||||||
searchStr: "line2",
|
searchStr: "line2",
|
||||||
content: ["line1", "line2", "line3"],
|
content: ["line1", "line2", "line3"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with correct index when startIndex is provided",
|
name: "should return a match with correct index when startIndex is provided",
|
||||||
searchStr: "line3",
|
searchStr: "line3",
|
||||||
content: ["line1", "line2", "line3", "line4", "line3"],
|
content: ["line1", "line2", "line3", "line4", "line3"],
|
||||||
startIndex: 3,
|
startIndex: 3,
|
||||||
expected: { index: 4, confidence: 1 },
|
expected: { index: 4, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match even if there are more lines in content",
|
name: "should return a match even if there are more lines in content",
|
||||||
searchStr: "line2",
|
searchStr: "line2",
|
||||||
content: ["line1", "line2", "line3", "line4", "line5"],
|
content: ["line1", "line2", "line3", "line4", "line5"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match even if the search string is at the beginning of the content",
|
name: "should return a match even if the search string is at the beginning of the content",
|
||||||
searchStr: "line1",
|
searchStr: "line1",
|
||||||
content: ["line1", "line2", "line3"],
|
content: ["line1", "line2", "line3"],
|
||||||
expected: { index: 0, confidence: 1 },
|
expected: { index: 0, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match even if the search string is at the end of the content",
|
name: "should return a match even if the search string is at the end of the content",
|
||||||
searchStr: "line3",
|
searchStr: "line3",
|
||||||
content: ["line1", "line2", "line3"],
|
content: ["line1", "line2", "line3"],
|
||||||
expected: { index: 2, confidence: 1 },
|
expected: { index: 2, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match for a multi-line search string",
|
name: "should return a match for a multi-line search string",
|
||||||
searchStr: "line2\nline3",
|
searchStr: "line2\nline3",
|
||||||
content: ["line1", "line2", "line3", "line4"],
|
content: ["line1", "line2", "line3", "line4"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return no match if a multi-line search string is not found",
|
name: "should return no match if a multi-line search string is not found",
|
||||||
searchStr: "line2\nline4",
|
searchStr: "line2\nline4",
|
||||||
content: ["line1", "line2", "line3", "line4"],
|
content: ["line1", "line2", "line3", "line4"],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
strategies: ["exact", "similarity"],
|
strategies: ["exact", "similarity"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with indentation",
|
name: "should return a match with indentation",
|
||||||
searchStr: " line2",
|
searchStr: " line2",
|
||||||
content: ["line1", " line2", "line3"],
|
content: ["line1", " line2", "line3"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with more complex indentation",
|
name: "should return a match with more complex indentation",
|
||||||
searchStr: " line3",
|
searchStr: " line3",
|
||||||
content: [" line1", " line2", " line3", " line4"],
|
content: [" line1", " line2", " line3", " line4"],
|
||||||
expected: { index: 2, confidence: 1 },
|
expected: { index: 2, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with mixed indentation",
|
name: "should return a match with mixed indentation",
|
||||||
searchStr: "\tline2",
|
searchStr: "\tline2",
|
||||||
content: [" line1", "\tline2", " line3"],
|
content: [" line1", "\tline2", " line3"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with mixed indentation and multi-line",
|
name: "should return a match with mixed indentation and multi-line",
|
||||||
searchStr: " line2\n\tline3",
|
searchStr: " line2\n\tline3",
|
||||||
content: ["line1", " line2", "\tline3", " line4"],
|
content: ["line1", " line2", "\tline3", " line4"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return no match if mixed indentation and multi-line is not found",
|
name: "should return no match if mixed indentation and multi-line is not found",
|
||||||
searchStr: " line2\n line4",
|
searchStr: " line2\n line4",
|
||||||
content: ["line1", " line2", "\tline3", " line4"],
|
content: ["line1", " line2", "\tline3", " line4"],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
strategies: ["exact", "similarity"],
|
strategies: ["exact", "similarity"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with leading and trailing spaces",
|
name: "should return a match with leading and trailing spaces",
|
||||||
searchStr: " line2 ",
|
searchStr: " line2 ",
|
||||||
content: ["line1", " line2 ", "line3"],
|
content: ["line1", " line2 ", "line3"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with leading and trailing tabs",
|
name: "should return a match with leading and trailing tabs",
|
||||||
searchStr: "\tline2\t",
|
searchStr: "\tline2\t",
|
||||||
content: ["line1", "\tline2\t", "line3"],
|
content: ["line1", "\tline2\t", "line3"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with mixed leading and trailing spaces and tabs",
|
name: "should return a match with mixed leading and trailing spaces and tabs",
|
||||||
searchStr: " \tline2\t ",
|
searchStr: " \tline2\t ",
|
||||||
content: ["line1", " \tline2\t ", "line3"],
|
content: ["line1", " \tline2\t ", "line3"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with mixed leading and trailing spaces and tabs and multi-line",
|
name: "should return a match with mixed leading and trailing spaces and tabs and multi-line",
|
||||||
searchStr: " \tline2\t \n line3 ",
|
searchStr: " \tline2\t \n line3 ",
|
||||||
content: ["line1", " \tline2\t ", " line3 ", "line4"],
|
content: ["line1", " \tline2\t ", " line3 ", "line4"],
|
||||||
expected: { index: 1, confidence: 1 },
|
expected: { index: 1, confidence: 1 },
|
||||||
strategies: ["exact", "similarity", "levenshtein"],
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found",
|
name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found",
|
||||||
searchStr: " \tline2\t \n line4 ",
|
searchStr: " \tline2\t \n line4 ",
|
||||||
content: ["line1", " \tline2\t ", " line3 ", "line4"],
|
content: ["line1", " \tline2\t ", " line3 ", "line4"],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
strategies: ["exact", "similarity"],
|
strategies: ["exact", "similarity"],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
describe("findExactMatch", () => {
|
describe("findExactMatch", () => {
|
||||||
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||||
if (!strategies?.includes("exact")) {
|
if (!strategies?.includes("exact")) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
it(name, () => {
|
it(name, () => {
|
||||||
const result = findExactMatch(searchStr, content, startIndex)
|
const result = findExactMatch(searchStr, content, startIndex)
|
||||||
expect(result.index).toBe(expected.index)
|
expect(result.index).toBe(expected.index)
|
||||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
@@ -155,16 +155,16 @@ describe("findExactMatch", () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
describe("findAnchorMatch", () => {
|
describe("findAnchorMatch", () => {
|
||||||
const anchorTestCases = [
|
const anchorTestCases = [
|
||||||
{
|
{
|
||||||
name: "should return no match if no anchors are found",
|
name: "should return no match if no anchors are found",
|
||||||
searchStr: " \n \n ",
|
searchStr: " \n \n ",
|
||||||
content: ["line1", "line2", "line3"],
|
content: ["line1", "line2", "line3"],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return no match if anchor positions cannot be validated",
|
name: "should return no match if anchor positions cannot be validated",
|
||||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
content: [
|
content: [
|
||||||
"different line 1",
|
"different line 1",
|
||||||
"different line 2",
|
"different line 2",
|
||||||
@@ -173,24 +173,24 @@ describe("findAnchorMatch", () => {
|
|||||||
"context line 1",
|
"context line 1",
|
||||||
"context line 2",
|
"context line 2",
|
||||||
],
|
],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match if anchor positions can be validated",
|
name: "should return a match if anchor positions can be validated",
|
||||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"],
|
content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"],
|
||||||
expected: { index: 2, confidence: 1 },
|
expected: { index: 2, confidence: 1 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match with correct index when startIndex is provided",
|
name: "should return a match with correct index when startIndex is provided",
|
||||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"],
|
content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"],
|
||||||
startIndex: 3,
|
startIndex: 3,
|
||||||
expected: { index: 3, confidence: 1 },
|
expected: { index: 3, confidence: 1 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match even if there are more lines in content",
|
name: "should return a match even if there are more lines in content",
|
||||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
content: [
|
content: [
|
||||||
"line1",
|
"line1",
|
||||||
"line2",
|
"line2",
|
||||||
@@ -201,30 +201,30 @@ describe("findAnchorMatch", () => {
|
|||||||
"extra line 1",
|
"extra line 1",
|
||||||
"extra line 2",
|
"extra line 2",
|
||||||
],
|
],
|
||||||
expected: { index: 2, confidence: 1 },
|
expected: { index: 2, confidence: 1 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match even if the anchor is at the beginning of the content",
|
name: "should return a match even if the anchor is at the beginning of the content",
|
||||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
content: ["unique line", "context line 1", "context line 2", "line 6"],
|
content: ["unique line", "context line 1", "context line 2", "line 6"],
|
||||||
expected: { index: 0, confidence: 1 },
|
expected: { index: 0, confidence: 1 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return a match even if the anchor is at the end of the content",
|
name: "should return a match even if the anchor is at the end of the content",
|
||||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
content: ["line1", "line2", "unique line", "context line 1", "context line 2"],
|
content: ["line1", "line2", "unique line", "context line 1", "context line 2"],
|
||||||
expected: { index: 2, confidence: 1 },
|
expected: { index: 2, confidence: 1 },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return no match if no valid anchor is found",
|
name: "should return no match if no valid anchor is found",
|
||||||
searchStr: "non-unique line\ncontext line 1\ncontext line 2",
|
searchStr: "non-unique line\ncontext line 1\ncontext line 2",
|
||||||
content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"],
|
content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"],
|
||||||
expected: { index: -1, confidence: 0 },
|
expected: { index: -1, confidence: 0 },
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => {
|
anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => {
|
||||||
it(name, () => {
|
it(name, () => {
|
||||||
const result = findAnchorMatch(searchStr, content, startIndex)
|
const result = findAnchorMatch(searchStr, content, startIndex)
|
||||||
expect(result.index).toBe(expected.index)
|
expect(result.index).toBe(expected.index)
|
||||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
@@ -234,11 +234,11 @@ describe("findAnchorMatch", () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
describe("findSimilarityMatch", () => {
|
describe("findSimilarityMatch", () => {
|
||||||
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||||
if (!strategies?.includes("similarity")) {
|
if (!strategies?.includes("similarity")) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
it(name, () => {
|
it(name, () => {
|
||||||
const result = findSimilarityMatch(searchStr, content, startIndex)
|
const result = findSimilarityMatch(searchStr, content, startIndex)
|
||||||
expect(result.index).toBe(expected.index)
|
expect(result.index).toBe(expected.index)
|
||||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
@@ -248,11 +248,11 @@ describe("findSimilarityMatch", () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
describe("findLevenshteinMatch", () => {
|
describe("findLevenshteinMatch", () => {
|
||||||
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||||
if (!strategies?.includes("levenshtein")) {
|
if (!strategies?.includes("levenshtein")) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
it(name, () => {
|
it(name, () => {
|
||||||
const result = findLevenshteinMatch(searchStr, content, startIndex)
|
const result = findLevenshteinMatch(searchStr, content, startIndex)
|
||||||
expect(result.index).toBe(expected.index)
|
expect(result.index).toBe(expected.index)
|
||||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ function inferIndentation(line: string, contextLines: string[], previousIndent:
|
|||||||
const contextLine = contextLines[0]
|
const contextLine = contextLines[0]
|
||||||
if (contextLine) {
|
if (contextLine) {
|
||||||
const contextMatch = contextLine.match(/^(\s+)/)
|
const contextMatch = contextLine.match(/^(\s+)/)
|
||||||
if (contextMatch) {
|
if (contextMatch) {
|
||||||
return contextMatch[1]
|
return contextMatch[1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -28,19 +28,15 @@ function inferIndentation(line: string, contextLines: string[], previousIndent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Context matching edit strategy
|
// Context matching edit strategy
|
||||||
export function applyContextMatching(
|
export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||||
hunk: Hunk,
|
if (matchPosition === -1) {
|
||||||
content: string[],
|
|
||||||
matchPosition: number,
|
|
||||||
): EditResult {
|
|
||||||
if (matchPosition === -1) {
|
|
||||||
return { confidence: 0, result: content, strategy: "context" }
|
return { confidence: 0, result: content, strategy: "context" }
|
||||||
}
|
}
|
||||||
|
|
||||||
const newResult = [...content.slice(0, matchPosition)]
|
const newResult = [...content.slice(0, matchPosition)]
|
||||||
let sourceIndex = matchPosition
|
let sourceIndex = matchPosition
|
||||||
|
|
||||||
for (const change of hunk.changes) {
|
for (const change of hunk.changes) {
|
||||||
if (change.type === "context") {
|
if (change.type === "context") {
|
||||||
// Use the original line from content if available
|
// Use the original line from content if available
|
||||||
if (sourceIndex < content.length) {
|
if (sourceIndex < content.length) {
|
||||||
@@ -82,20 +78,16 @@ export function applyContextMatching(
|
|||||||
|
|
||||||
const confidence = validateEditResult(hunk, afterText)
|
const confidence = validateEditResult(hunk, afterText)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
confidence,
|
confidence,
|
||||||
result: newResult,
|
result: newResult,
|
||||||
strategy: "context"
|
strategy: "context",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DMP edit strategy
|
// DMP edit strategy
|
||||||
export function applyDMP(
|
export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||||
hunk: Hunk,
|
if (matchPosition === -1) {
|
||||||
content: string[],
|
|
||||||
matchPosition: number,
|
|
||||||
): EditResult {
|
|
||||||
if (matchPosition === -1) {
|
|
||||||
return { confidence: 0, result: content, strategy: "dmp" }
|
return { confidence: 0, result: content, strategy: "dmp" }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,9 +97,9 @@ export function applyDMP(
|
|||||||
const beforeLineCount = hunk.changes
|
const beforeLineCount = hunk.changes
|
||||||
.filter((change) => change.type === "context" || change.type === "remove")
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
.reduce((count, change) => count + change.content.split("\n").length, 0)
|
.reduce((count, change) => count + change.content.split("\n").length, 0)
|
||||||
|
|
||||||
// Build BEFORE block (context + removals)
|
// Build BEFORE block (context + removals)
|
||||||
const beforeLines = hunk.changes
|
const beforeLines = hunk.changes
|
||||||
.filter((change) => change.type === "context" || change.type === "remove")
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
.map((change) => {
|
.map((change) => {
|
||||||
if (change.originalLine) {
|
if (change.originalLine) {
|
||||||
@@ -115,9 +107,9 @@ export function applyDMP(
|
|||||||
}
|
}
|
||||||
return change.indent ? change.indent + change.content : change.content
|
return change.indent ? change.indent + change.content : change.content
|
||||||
})
|
})
|
||||||
|
|
||||||
// Build AFTER block (context + additions)
|
// Build AFTER block (context + additions)
|
||||||
const afterLines = hunk.changes
|
const afterLines = hunk.changes
|
||||||
.filter((change) => change.type === "context" || change.type === "add")
|
.filter((change) => change.type === "context" || change.type === "add")
|
||||||
.map((change) => {
|
.map((change) => {
|
||||||
if (change.originalLine) {
|
if (change.originalLine) {
|
||||||
@@ -139,17 +131,17 @@ export function applyDMP(
|
|||||||
const patchedLines = patchedText.split("\n")
|
const patchedLines = patchedText.split("\n")
|
||||||
|
|
||||||
// Construct final result
|
// Construct final result
|
||||||
const newResult = [
|
const newResult = [
|
||||||
...content.slice(0, matchPosition),
|
...content.slice(0, matchPosition),
|
||||||
...patchedLines,
|
...patchedLines,
|
||||||
...content.slice(matchPosition + beforeLineCount),
|
...content.slice(matchPosition + beforeLineCount),
|
||||||
]
|
]
|
||||||
|
|
||||||
const confidence = validateEditResult(hunk, patchedText)
|
const confidence = validateEditResult(hunk, patchedText)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
confidence,
|
confidence,
|
||||||
result: newResult,
|
result: newResult,
|
||||||
strategy: "dmp",
|
strategy: "dmp",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,7 +163,7 @@ export async function applyGitFallback(hunk: Hunk, content: string[]): Promise<E
|
|||||||
const searchLines = hunk.changes
|
const searchLines = hunk.changes
|
||||||
.filter((change) => change.type === "context" || change.type === "remove")
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
.map((change) => change.originalLine || change.indent + change.content)
|
.map((change) => change.originalLine || change.indent + change.content)
|
||||||
|
|
||||||
const replaceLines = hunk.changes
|
const replaceLines = hunk.changes
|
||||||
.filter((change) => change.type === "context" || change.type === "add")
|
.filter((change) => change.type === "context" || change.type === "add")
|
||||||
.map((change) => change.originalLine || change.indent + change.content)
|
.map((change) => change.originalLine || change.indent + change.content)
|
||||||
@@ -272,16 +264,16 @@ export async function applyGitFallback(hunk: Hunk, content: string[]): Promise<E
|
|||||||
|
|
||||||
// Main edit function that tries strategies sequentially
|
// Main edit function that tries strategies sequentially
|
||||||
export async function applyEdit(
|
export async function applyEdit(
|
||||||
hunk: Hunk,
|
hunk: Hunk,
|
||||||
content: string[],
|
content: string[],
|
||||||
matchPosition: number,
|
matchPosition: number,
|
||||||
confidence: number,
|
confidence: number,
|
||||||
confidenceThreshold: number = 0.97
|
confidenceThreshold: number = 0.97,
|
||||||
): Promise<EditResult> {
|
): Promise<EditResult> {
|
||||||
// Don't attempt regular edits if confidence is too low
|
// Don't attempt regular edits if confidence is too low
|
||||||
if (confidence < confidenceThreshold) {
|
if (confidence < confidenceThreshold) {
|
||||||
console.log(
|
console.log(
|
||||||
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`
|
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
|
||||||
)
|
)
|
||||||
return applyGitFallback(hunk, content)
|
return applyGitFallback(hunk, content)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,7 +242,7 @@ Your diff here
|
|||||||
originalContent: string,
|
originalContent: string,
|
||||||
diffContent: string,
|
diffContent: string,
|
||||||
startLine?: number,
|
startLine?: number,
|
||||||
endLine?: number
|
endLine?: number,
|
||||||
): Promise<DiffResult> {
|
): Promise<DiffResult> {
|
||||||
const parsedDiff = this.parseUnifiedDiff(diffContent)
|
const parsedDiff = this.parseUnifiedDiff(diffContent)
|
||||||
const originalLines = originalContent.split("\n")
|
const originalLines = originalContent.split("\n")
|
||||||
@@ -280,7 +280,7 @@ Your diff here
|
|||||||
subHunkResult,
|
subHunkResult,
|
||||||
subSearchResult.index,
|
subSearchResult.index,
|
||||||
subSearchResult.confidence,
|
subSearchResult.confidence,
|
||||||
this.confidenceThreshold
|
this.confidenceThreshold,
|
||||||
)
|
)
|
||||||
if (subEditResult.confidence >= this.confidenceThreshold) {
|
if (subEditResult.confidence >= this.confidenceThreshold) {
|
||||||
subHunkResult = subEditResult.result
|
subHunkResult = subEditResult.result
|
||||||
@@ -302,12 +302,12 @@ Your diff here
|
|||||||
const contextRatio = contextLines / totalLines
|
const contextRatio = contextLines / totalLines
|
||||||
|
|
||||||
let errorMsg = `Failed to find a matching location in the file (${Math.floor(
|
let errorMsg = `Failed to find a matching location in the file (${Math.floor(
|
||||||
confidence * 100
|
confidence * 100,
|
||||||
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
|
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
|
||||||
errorMsg += "Debug Info:\n"
|
errorMsg += "Debug Info:\n"
|
||||||
errorMsg += `- Search Strategy Used: ${strategy}\n`
|
errorMsg += `- Search Strategy Used: ${strategy}\n`
|
||||||
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
|
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
|
||||||
contextRatio * 100
|
contextRatio * 100,
|
||||||
)}%)\n`
|
)}%)\n`
|
||||||
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
|
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
|
||||||
|
|
||||||
@@ -339,7 +339,7 @@ Your diff here
|
|||||||
} else {
|
} else {
|
||||||
// Edit failure - likely due to content mismatch
|
// Edit failure - likely due to content mismatch
|
||||||
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
|
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
|
||||||
editResult.confidence * 100
|
editResult.confidence * 100,
|
||||||
)}% confidence)\n\n`
|
)}% confidence)\n\n`
|
||||||
errorMsg += "Debug Info:\n"
|
errorMsg += "Debug Info:\n"
|
||||||
errorMsg += "- The location was found but the content didn't match exactly\n"
|
errorMsg += "- The location was found but the content didn't match exactly\n"
|
||||||
|
|||||||
@@ -69,26 +69,26 @@ export function getDMPSimilarity(original: string, modified: string): number {
|
|||||||
export function validateEditResult(hunk: Hunk, result: string): number {
|
export function validateEditResult(hunk: Hunk, result: string): number {
|
||||||
// Build the expected text from the hunk
|
// Build the expected text from the hunk
|
||||||
const expectedText = hunk.changes
|
const expectedText = hunk.changes
|
||||||
.filter(change => change.type === "context" || change.type === "add")
|
.filter((change) => change.type === "context" || change.type === "add")
|
||||||
.map(change => change.indent ? change.indent + change.content : change.content)
|
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||||
.join("\n");
|
.join("\n")
|
||||||
|
|
||||||
// Calculate similarity between the result and expected text
|
// Calculate similarity between the result and expected text
|
||||||
const similarity = getDMPSimilarity(expectedText, result);
|
const similarity = getDMPSimilarity(expectedText, result)
|
||||||
|
|
||||||
// If the result is unchanged from original, return low confidence
|
// If the result is unchanged from original, return low confidence
|
||||||
const originalText = hunk.changes
|
const originalText = hunk.changes
|
||||||
.filter(change => change.type === "context" || change.type === "remove")
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
.map(change => change.indent ? change.indent + change.content : change.content)
|
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||||
.join("\n");
|
.join("\n")
|
||||||
|
|
||||||
const originalSimilarity = getDMPSimilarity(originalText, result);
|
const originalSimilarity = getDMPSimilarity(originalText, result)
|
||||||
if (originalSimilarity > 0.97 && similarity !== 1) {
|
if (originalSimilarity > 0.97 && similarity !== 1) {
|
||||||
return 0.8 * similarity; // Some confidence since we found the right location
|
return 0.8 * similarity // Some confidence since we found the right location
|
||||||
}
|
}
|
||||||
|
|
||||||
// For partial matches, scale the confidence but keep it high if we're close
|
// For partial matches, scale the confidence but keep it high if we're close
|
||||||
return similarity;
|
return similarity
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to validate context lines against original content
|
// Helper function to validate context lines against original content
|
||||||
@@ -114,7 +114,7 @@ function validateContextLines(searchStr: string, content: string, confidenceThre
|
|||||||
function createOverlappingWindows(
|
function createOverlappingWindows(
|
||||||
content: string[],
|
content: string[],
|
||||||
searchSize: number,
|
searchSize: number,
|
||||||
overlapSize: number = DEFAULT_OVERLAP_SIZE
|
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||||
): { window: string[]; startIndex: number }[] {
|
): { window: string[]; startIndex: number }[] {
|
||||||
const windows: { window: string[]; startIndex: number }[] = []
|
const windows: { window: string[]; startIndex: number }[] = []
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ function createOverlappingWindows(
|
|||||||
// Helper function to combine overlapping matches
|
// Helper function to combine overlapping matches
|
||||||
function combineOverlappingMatches(
|
function combineOverlappingMatches(
|
||||||
matches: (SearchResult & { windowIndex: number })[],
|
matches: (SearchResult & { windowIndex: number })[],
|
||||||
overlapSize: number = DEFAULT_OVERLAP_SIZE
|
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||||
): SearchResult[] {
|
): SearchResult[] {
|
||||||
if (matches.length === 0) {
|
if (matches.length === 0) {
|
||||||
return []
|
return []
|
||||||
@@ -162,7 +162,7 @@ function combineOverlappingMatches(
|
|||||||
(m) =>
|
(m) =>
|
||||||
Math.abs(m.windowIndex - match.windowIndex) === 1 &&
|
Math.abs(m.windowIndex - match.windowIndex) === 1 &&
|
||||||
Math.abs(m.index - match.index) <= overlapSize &&
|
Math.abs(m.index - match.index) <= overlapSize &&
|
||||||
!usedIndices.has(m.windowIndex)
|
!usedIndices.has(m.windowIndex),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (overlapping.length > 0) {
|
if (overlapping.length > 0) {
|
||||||
@@ -196,7 +196,7 @@ export function findExactMatch(
|
|||||||
searchStr: string,
|
searchStr: string,
|
||||||
content: string[],
|
content: string[],
|
||||||
startIndex: number = 0,
|
startIndex: number = 0,
|
||||||
confidenceThreshold: number = 0.97
|
confidenceThreshold: number = 0.97,
|
||||||
): SearchResult {
|
): SearchResult {
|
||||||
const searchLines = searchStr.split("\n")
|
const searchLines = searchStr.split("\n")
|
||||||
const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
|
const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
|
||||||
@@ -210,7 +210,7 @@ export function findExactMatch(
|
|||||||
const matchedContent = windowData.window
|
const matchedContent = windowData.window
|
||||||
.slice(
|
.slice(
|
||||||
windowStr.slice(0, exactMatch).split("\n").length - 1,
|
windowStr.slice(0, exactMatch).split("\n").length - 1,
|
||||||
windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length
|
windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length,
|
||||||
)
|
)
|
||||||
.join("\n")
|
.join("\n")
|
||||||
|
|
||||||
@@ -236,7 +236,7 @@ export function findSimilarityMatch(
|
|||||||
searchStr: string,
|
searchStr: string,
|
||||||
content: string[],
|
content: string[],
|
||||||
startIndex: number = 0,
|
startIndex: number = 0,
|
||||||
confidenceThreshold: number = 0.97
|
confidenceThreshold: number = 0.97,
|
||||||
): SearchResult {
|
): SearchResult {
|
||||||
const searchLines = searchStr.split("\n")
|
const searchLines = searchStr.split("\n")
|
||||||
let bestScore = 0
|
let bestScore = 0
|
||||||
@@ -269,7 +269,7 @@ export function findLevenshteinMatch(
|
|||||||
searchStr: string,
|
searchStr: string,
|
||||||
content: string[],
|
content: string[],
|
||||||
startIndex: number = 0,
|
startIndex: number = 0,
|
||||||
confidenceThreshold: number = 0.97
|
confidenceThreshold: number = 0.97,
|
||||||
): SearchResult {
|
): SearchResult {
|
||||||
const searchLines = searchStr.split("\n")
|
const searchLines = searchStr.split("\n")
|
||||||
const candidates = []
|
const candidates = []
|
||||||
@@ -324,7 +324,7 @@ export function findAnchorMatch(
|
|||||||
searchStr: string,
|
searchStr: string,
|
||||||
content: string[],
|
content: string[],
|
||||||
startIndex: number = 0,
|
startIndex: number = 0,
|
||||||
confidenceThreshold: number = 0.97
|
confidenceThreshold: number = 0.97,
|
||||||
): SearchResult {
|
): SearchResult {
|
||||||
const searchLines = searchStr.split("\n")
|
const searchLines = searchStr.split("\n")
|
||||||
const { first, last } = identifyAnchors(searchStr)
|
const { first, last } = identifyAnchors(searchStr)
|
||||||
@@ -391,7 +391,7 @@ export function findBestMatch(
|
|||||||
searchStr: string,
|
searchStr: string,
|
||||||
content: string[],
|
content: string[],
|
||||||
startIndex: number = 0,
|
startIndex: number = 0,
|
||||||
confidenceThreshold: number = 0.97
|
confidenceThreshold: number = 0.97,
|
||||||
): SearchResult {
|
): SearchResult {
|
||||||
const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]
|
const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
export type Change = {
|
export type Change = {
|
||||||
type: 'context' | 'add' | 'remove';
|
type: "context" | "add" | "remove"
|
||||||
content: string;
|
content: string
|
||||||
indent: string;
|
indent: string
|
||||||
originalLine?: string;
|
originalLine?: string
|
||||||
};
|
}
|
||||||
|
|
||||||
export type Hunk = {
|
export type Hunk = {
|
||||||
changes: Change[];
|
changes: Change[]
|
||||||
};
|
}
|
||||||
|
|
||||||
export type Diff = {
|
export type Diff = {
|
||||||
hunks: Hunk[];
|
hunks: Hunk[]
|
||||||
};
|
}
|
||||||
|
|
||||||
export type EditResult = {
|
export type EditResult = {
|
||||||
confidence: number;
|
confidence: number
|
||||||
result: string[];
|
result: string[]
|
||||||
strategy: string;
|
strategy: string
|
||||||
};
|
}
|
||||||
|
|||||||
@@ -1,72 +1,74 @@
|
|||||||
import { DiffStrategy, DiffResult } from "../types"
|
import { DiffStrategy, DiffResult } from "../types"
|
||||||
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
|
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
|
||||||
|
|
||||||
const BUFFER_LINES = 20; // Number of extra context lines to show before and after matches
|
const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches
|
||||||
|
|
||||||
function levenshteinDistance(a: string, b: string): number {
|
function levenshteinDistance(a: string, b: string): number {
|
||||||
const matrix: number[][] = [];
|
const matrix: number[][] = []
|
||||||
|
|
||||||
// Initialize matrix
|
// Initialize matrix
|
||||||
for (let i = 0; i <= a.length; i++) {
|
for (let i = 0; i <= a.length; i++) {
|
||||||
matrix[i] = [i];
|
matrix[i] = [i]
|
||||||
}
|
}
|
||||||
for (let j = 0; j <= b.length; j++) {
|
for (let j = 0; j <= b.length; j++) {
|
||||||
matrix[0][j] = j;
|
matrix[0][j] = j
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill matrix
|
// Fill matrix
|
||||||
for (let i = 1; i <= a.length; i++) {
|
for (let i = 1; i <= a.length; i++) {
|
||||||
for (let j = 1; j <= b.length; j++) {
|
for (let j = 1; j <= b.length; j++) {
|
||||||
if (a[i-1] === b[j-1]) {
|
if (a[i - 1] === b[j - 1]) {
|
||||||
matrix[i][j] = matrix[i-1][j-1];
|
matrix[i][j] = matrix[i - 1][j - 1]
|
||||||
} else {
|
} else {
|
||||||
matrix[i][j] = Math.min(
|
matrix[i][j] = Math.min(
|
||||||
matrix[i-1][j-1] + 1, // substitution
|
matrix[i - 1][j - 1] + 1, // substitution
|
||||||
matrix[i][j-1] + 1, // insertion
|
matrix[i][j - 1] + 1, // insertion
|
||||||
matrix[i-1][j] + 1 // deletion
|
matrix[i - 1][j] + 1, // deletion
|
||||||
);
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return matrix[a.length][b.length];
|
return matrix[a.length][b.length]
|
||||||
}
|
}
|
||||||
|
|
||||||
function getSimilarity(original: string, search: string): number {
|
function getSimilarity(original: string, search: string): number {
|
||||||
if (search === '') {
|
if (search === "") {
|
||||||
return 1;
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize strings by removing extra whitespace but preserve case
|
// Normalize strings by removing extra whitespace but preserve case
|
||||||
const normalizeStr = (str: string) => str.replace(/\s+/g, ' ').trim();
|
const normalizeStr = (str: string) => str.replace(/\s+/g, " ").trim()
|
||||||
|
|
||||||
const normalizedOriginal = normalizeStr(original);
|
const normalizedOriginal = normalizeStr(original)
|
||||||
const normalizedSearch = normalizeStr(search);
|
const normalizedSearch = normalizeStr(search)
|
||||||
|
|
||||||
if (normalizedOriginal === normalizedSearch) { return 1; }
|
if (normalizedOriginal === normalizedSearch) {
|
||||||
|
return 1
|
||||||
// Calculate Levenshtein distance
|
}
|
||||||
const distance = levenshteinDistance(normalizedOriginal, normalizedSearch);
|
|
||||||
|
// Calculate Levenshtein distance
|
||||||
// Calculate similarity ratio (0 to 1, where 1 is exact match)
|
const distance = levenshteinDistance(normalizedOriginal, normalizedSearch)
|
||||||
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length);
|
|
||||||
return 1 - (distance / maxLength);
|
// Calculate similarity ratio (0 to 1, where 1 is exact match)
|
||||||
|
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length)
|
||||||
|
return 1 - distance / maxLength
|
||||||
}
|
}
|
||||||
|
|
||||||
export class SearchReplaceDiffStrategy implements DiffStrategy {
|
export class SearchReplaceDiffStrategy implements DiffStrategy {
|
||||||
private fuzzyThreshold: number;
|
private fuzzyThreshold: number
|
||||||
private bufferLines: number;
|
private bufferLines: number
|
||||||
|
|
||||||
constructor(fuzzyThreshold?: number, bufferLines?: number) {
|
constructor(fuzzyThreshold?: number, bufferLines?: number) {
|
||||||
// Use provided threshold or default to exact matching (1.0)
|
// Use provided threshold or default to exact matching (1.0)
|
||||||
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
|
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
|
||||||
// so we use it directly here
|
// so we use it directly here
|
||||||
this.fuzzyThreshold = fuzzyThreshold ?? 1.0;
|
this.fuzzyThreshold = fuzzyThreshold ?? 1.0
|
||||||
this.bufferLines = bufferLines ?? BUFFER_LINES;
|
this.bufferLines = bufferLines ?? BUFFER_LINES
|
||||||
}
|
}
|
||||||
|
|
||||||
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||||
return `## apply_diff
|
return `## apply_diff
|
||||||
Description: Request to replace existing code using a search and replace block.
|
Description: Request to replace existing code using a search and replace block.
|
||||||
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
|
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
|
||||||
The tool will maintain proper indentation and formatting while making changes.
|
The tool will maintain proper indentation and formatting while making changes.
|
||||||
@@ -125,193 +127,204 @@ Your search/replace content here
|
|||||||
<start_line>1</start_line>
|
<start_line>1</start_line>
|
||||||
<end_line>5</end_line>
|
<end_line>5</end_line>
|
||||||
</apply_diff>`
|
</apply_diff>`
|
||||||
}
|
}
|
||||||
|
|
||||||
async applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult> {
|
async applyDiff(
|
||||||
// Extract the search and replace blocks
|
originalContent: string,
|
||||||
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/);
|
diffContent: string,
|
||||||
if (!match) {
|
startLine?: number,
|
||||||
return {
|
endLine?: number,
|
||||||
success: false,
|
): Promise<DiffResult> {
|
||||||
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`
|
// Extract the search and replace blocks
|
||||||
};
|
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
|
||||||
}
|
if (!match) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let [_, searchContent, replaceContent] = match;
|
let [_, searchContent, replaceContent] = match
|
||||||
|
|
||||||
// Detect line ending from original content
|
// Detect line ending from original content
|
||||||
const lineEnding = originalContent.includes('\r\n') ? '\r\n' : '\n';
|
const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"
|
||||||
|
|
||||||
// Strip line numbers from search and replace content if every line starts with a line number
|
// Strip line numbers from search and replace content if every line starts with a line number
|
||||||
if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
|
if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
|
||||||
searchContent = stripLineNumbers(searchContent);
|
searchContent = stripLineNumbers(searchContent)
|
||||||
replaceContent = stripLineNumbers(replaceContent);
|
replaceContent = stripLineNumbers(replaceContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split content into lines, handling both \n and \r\n
|
|
||||||
const searchLines = searchContent === '' ? [] : searchContent.split(/\r?\n/);
|
|
||||||
const replaceLines = replaceContent === '' ? [] : replaceContent.split(/\r?\n/);
|
|
||||||
const originalLines = originalContent.split(/\r?\n/);
|
|
||||||
|
|
||||||
// Validate that empty search requires start line
|
// Split content into lines, handling both \n and \r\n
|
||||||
if (searchLines.length === 0 && !startLine) {
|
const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
|
||||||
return {
|
const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
|
||||||
success: false,
|
const originalLines = originalContent.split(/\r?\n/)
|
||||||
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate that empty search requires same start and end line
|
// Validate that empty search requires start line
|
||||||
if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
|
if (searchLines.length === 0 && !startLine) {
|
||||||
return {
|
return {
|
||||||
success: false,
|
success: false,
|
||||||
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`
|
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize search variables
|
|
||||||
let matchIndex = -1;
|
|
||||||
let bestMatchScore = 0;
|
|
||||||
let bestMatchContent = "";
|
|
||||||
const searchChunk = searchLines.join('\n');
|
|
||||||
|
|
||||||
// Determine search bounds
|
// Validate that empty search requires same start and end line
|
||||||
let searchStartIndex = 0;
|
if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
|
||||||
let searchEndIndex = originalLines.length;
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate and handle line range if provided
|
// Initialize search variables
|
||||||
if (startLine && endLine) {
|
let matchIndex = -1
|
||||||
// Convert to 0-based index
|
let bestMatchScore = 0
|
||||||
const exactStartIndex = startLine - 1;
|
let bestMatchContent = ""
|
||||||
const exactEndIndex = endLine - 1;
|
const searchChunk = searchLines.join("\n")
|
||||||
|
|
||||||
if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
|
// Determine search bounds
|
||||||
return {
|
let searchStartIndex = 0
|
||||||
success: false,
|
let searchEndIndex = originalLines.length
|
||||||
error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try exact match first
|
// Validate and handle line range if provided
|
||||||
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join('\n');
|
if (startLine && endLine) {
|
||||||
const similarity = getSimilarity(originalChunk, searchChunk);
|
// Convert to 0-based index
|
||||||
if (similarity >= this.fuzzyThreshold) {
|
const exactStartIndex = startLine - 1
|
||||||
matchIndex = exactStartIndex;
|
const exactEndIndex = endLine - 1
|
||||||
bestMatchScore = similarity;
|
|
||||||
bestMatchContent = originalChunk;
|
|
||||||
} else {
|
|
||||||
// Set bounds for buffered search
|
|
||||||
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1));
|
|
||||||
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no match found yet, try middle-out search within bounds
|
if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
|
||||||
if (matchIndex === -1) {
|
return {
|
||||||
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2);
|
success: false,
|
||||||
let leftIndex = midPoint;
|
error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
|
||||||
let rightIndex = midPoint + 1;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Search outward from the middle within bounds
|
// Try exact match first
|
||||||
while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
|
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
|
||||||
// Check left side if still in range
|
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||||
if (leftIndex >= searchStartIndex) {
|
if (similarity >= this.fuzzyThreshold) {
|
||||||
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join('\n');
|
matchIndex = exactStartIndex
|
||||||
const similarity = getSimilarity(originalChunk, searchChunk);
|
bestMatchScore = similarity
|
||||||
if (similarity > bestMatchScore) {
|
bestMatchContent = originalChunk
|
||||||
bestMatchScore = similarity;
|
} else {
|
||||||
matchIndex = leftIndex;
|
// Set bounds for buffered search
|
||||||
bestMatchContent = originalChunk;
|
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
|
||||||
}
|
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
|
||||||
leftIndex--;
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check right side if still in range
|
// If no match found yet, try middle-out search within bounds
|
||||||
if (rightIndex <= searchEndIndex - searchLines.length) {
|
if (matchIndex === -1) {
|
||||||
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join('\n');
|
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
|
||||||
const similarity = getSimilarity(originalChunk, searchChunk);
|
let leftIndex = midPoint
|
||||||
if (similarity > bestMatchScore) {
|
let rightIndex = midPoint + 1
|
||||||
bestMatchScore = similarity;
|
|
||||||
matchIndex = rightIndex;
|
|
||||||
bestMatchContent = originalChunk;
|
|
||||||
}
|
|
||||||
rightIndex++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Require similarity to meet threshold
|
// Search outward from the middle within bounds
|
||||||
if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
|
while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
|
||||||
const searchChunk = searchLines.join('\n');
|
// Check left side if still in range
|
||||||
const originalContentSection = startLine !== undefined && endLine !== undefined
|
if (leftIndex >= searchStartIndex) {
|
||||||
? `\n\nOriginal Content:\n${addLineNumbers(
|
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
|
||||||
originalLines.slice(
|
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||||
Math.max(0, startLine - 1 - this.bufferLines),
|
if (similarity > bestMatchScore) {
|
||||||
Math.min(originalLines.length, endLine + this.bufferLines)
|
bestMatchScore = similarity
|
||||||
).join('\n'),
|
matchIndex = leftIndex
|
||||||
Math.max(1, startLine - this.bufferLines)
|
bestMatchContent = originalChunk
|
||||||
)}`
|
}
|
||||||
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join('\n'))}`;
|
leftIndex--
|
||||||
|
}
|
||||||
|
|
||||||
const bestMatchSection = bestMatchContent
|
// Check right side if still in range
|
||||||
? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
|
if (rightIndex <= searchEndIndex - searchLines.length) {
|
||||||
: `\n\nBest Match Found:\n(no match)`;
|
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
|
||||||
|
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||||
|
if (similarity > bestMatchScore) {
|
||||||
|
bestMatchScore = similarity
|
||||||
|
matchIndex = rightIndex
|
||||||
|
bestMatchContent = originalChunk
|
||||||
|
}
|
||||||
|
rightIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const lineRange = startLine || endLine ?
|
// Require similarity to meet threshold
|
||||||
` at ${startLine ? `start: ${startLine}` : 'start'} to ${endLine ? `end: ${endLine}` : 'end'}` : '';
|
if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
|
||||||
return {
|
const searchChunk = searchLines.join("\n")
|
||||||
success: false,
|
const originalContentSection =
|
||||||
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : 'start to end'}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`
|
startLine !== undefined && endLine !== undefined
|
||||||
};
|
? `\n\nOriginal Content:\n${addLineNumbers(
|
||||||
}
|
originalLines
|
||||||
|
.slice(
|
||||||
|
Math.max(0, startLine - 1 - this.bufferLines),
|
||||||
|
Math.min(originalLines.length, endLine + this.bufferLines),
|
||||||
|
)
|
||||||
|
.join("\n"),
|
||||||
|
Math.max(1, startLine - this.bufferLines),
|
||||||
|
)}`
|
||||||
|
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`
|
||||||
|
|
||||||
// Get the matched lines from the original content
|
const bestMatchSection = bestMatchContent
|
||||||
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length);
|
? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
|
||||||
|
: `\n\nBest Match Found:\n(no match)`
|
||||||
// Get the exact indentation (preserving tabs/spaces) of each line
|
|
||||||
const originalIndents = matchedLines.map(line => {
|
|
||||||
const match = line.match(/^[\t ]*/);
|
|
||||||
return match ? match[0] : '';
|
|
||||||
});
|
|
||||||
|
|
||||||
// Get the exact indentation of each line in the search block
|
const lineRange =
|
||||||
const searchIndents = searchLines.map(line => {
|
startLine || endLine
|
||||||
const match = line.match(/^[\t ]*/);
|
? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
|
||||||
return match ? match[0] : '';
|
: ""
|
||||||
});
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply the replacement while preserving exact indentation
|
// Get the matched lines from the original content
|
||||||
const indentedReplaceLines = replaceLines.map((line, i) => {
|
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)
|
||||||
// Get the matched line's exact indentation
|
|
||||||
const matchedIndent = originalIndents[0] || '';
|
|
||||||
|
|
||||||
// Get the current line's indentation relative to the search content
|
|
||||||
const currentIndentMatch = line.match(/^[\t ]*/);
|
|
||||||
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : '';
|
|
||||||
const searchBaseIndent = searchIndents[0] || '';
|
|
||||||
|
|
||||||
// Calculate the relative indentation level
|
|
||||||
const searchBaseLevel = searchBaseIndent.length;
|
|
||||||
const currentLevel = currentIndent.length;
|
|
||||||
const relativeLevel = currentLevel - searchBaseLevel;
|
|
||||||
|
|
||||||
// If relative level is negative, remove indentation from matched indent
|
|
||||||
// If positive, add to matched indent
|
|
||||||
const finalIndent = relativeLevel < 0
|
|
||||||
? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
|
|
||||||
: matchedIndent + currentIndent.slice(searchBaseLevel);
|
|
||||||
|
|
||||||
return finalIndent + line.trim();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Construct the final content
|
// Get the exact indentation (preserving tabs/spaces) of each line
|
||||||
const beforeMatch = originalLines.slice(0, matchIndex);
|
const originalIndents = matchedLines.map((line) => {
|
||||||
const afterMatch = originalLines.slice(matchIndex + searchLines.length);
|
const match = line.match(/^[\t ]*/)
|
||||||
|
return match ? match[0] : ""
|
||||||
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding);
|
})
|
||||||
return {
|
|
||||||
success: true,
|
// Get the exact indentation of each line in the search block
|
||||||
content: finalContent
|
const searchIndents = searchLines.map((line) => {
|
||||||
};
|
const match = line.match(/^[\t ]*/)
|
||||||
}
|
return match ? match[0] : ""
|
||||||
}
|
})
|
||||||
|
|
||||||
|
// Apply the replacement while preserving exact indentation
|
||||||
|
const indentedReplaceLines = replaceLines.map((line, i) => {
|
||||||
|
// Get the matched line's exact indentation
|
||||||
|
const matchedIndent = originalIndents[0] || ""
|
||||||
|
|
||||||
|
// Get the current line's indentation relative to the search content
|
||||||
|
const currentIndentMatch = line.match(/^[\t ]*/)
|
||||||
|
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
|
||||||
|
const searchBaseIndent = searchIndents[0] || ""
|
||||||
|
|
||||||
|
// Calculate the relative indentation level
|
||||||
|
const searchBaseLevel = searchBaseIndent.length
|
||||||
|
const currentLevel = currentIndent.length
|
||||||
|
const relativeLevel = currentLevel - searchBaseLevel
|
||||||
|
|
||||||
|
// If relative level is negative, remove indentation from matched indent
|
||||||
|
// If positive, add to matched indent
|
||||||
|
const finalIndent =
|
||||||
|
relativeLevel < 0
|
||||||
|
? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
|
||||||
|
: matchedIndent + currentIndent.slice(searchBaseLevel)
|
||||||
|
|
||||||
|
return finalIndent + line.trim()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Construct the final content
|
||||||
|
const beforeMatch = originalLines.slice(0, matchIndex)
|
||||||
|
const afterMatch = originalLines.slice(matchIndex + searchLines.length)
|
||||||
|
|
||||||
|
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
content: finalContent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import { applyPatch } from "diff"
|
|||||||
import { DiffStrategy, DiffResult } from "../types"
|
import { DiffStrategy, DiffResult } from "../types"
|
||||||
|
|
||||||
export class UnifiedDiffStrategy implements DiffStrategy {
|
export class UnifiedDiffStrategy implements DiffStrategy {
|
||||||
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||||
return `## apply_diff
|
return `## apply_diff
|
||||||
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
|
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
@@ -106,32 +106,32 @@ Usage:
|
|||||||
Your diff here
|
Your diff here
|
||||||
</diff>
|
</diff>
|
||||||
</apply_diff>`
|
</apply_diff>`
|
||||||
}
|
}
|
||||||
|
|
||||||
async applyDiff(originalContent: string, diffContent: string): Promise<DiffResult> {
|
async applyDiff(originalContent: string, diffContent: string): Promise<DiffResult> {
|
||||||
try {
|
try {
|
||||||
const result = applyPatch(originalContent, diffContent)
|
const result = applyPatch(originalContent, diffContent)
|
||||||
if (result === false) {
|
if (result === false) {
|
||||||
return {
|
return {
|
||||||
success: false,
|
success: false,
|
||||||
error: "Failed to apply unified diff - patch rejected",
|
error: "Failed to apply unified diff - patch rejected",
|
||||||
details: {
|
details: {
|
||||||
searchContent: diffContent
|
searchContent: diffContent,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
content: result
|
content: result,
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
return {
|
return {
|
||||||
success: false,
|
success: false,
|
||||||
error: `Error applying unified diff: ${error.message}`,
|
error: `Error applying unified diff: ${error.message}`,
|
||||||
details: {
|
details: {
|
||||||
searchContent: diffContent
|
searchContent: diffContent,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,31 +2,35 @@
|
|||||||
* Interface for implementing different diff strategies
|
* Interface for implementing different diff strategies
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export type DiffResult =
|
export type DiffResult =
|
||||||
| { success: true; content: string }
|
| { success: true; content: string }
|
||||||
| { success: false; error: string; details?: {
|
| {
|
||||||
similarity?: number;
|
success: false
|
||||||
threshold?: number;
|
error: string
|
||||||
matchedRange?: { start: number; end: number };
|
details?: {
|
||||||
searchContent?: string;
|
similarity?: number
|
||||||
bestMatch?: string;
|
threshold?: number
|
||||||
}};
|
matchedRange?: { start: number; end: number }
|
||||||
|
searchContent?: string
|
||||||
|
bestMatch?: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export interface DiffStrategy {
|
export interface DiffStrategy {
|
||||||
/**
|
/**
|
||||||
* Get the tool description for this diff strategy
|
* Get the tool description for this diff strategy
|
||||||
* @param args The tool arguments including cwd and toolOptions
|
* @param args The tool arguments including cwd and toolOptions
|
||||||
* @returns The complete tool description including format requirements and examples
|
* @returns The complete tool description including format requirements and examples
|
||||||
*/
|
*/
|
||||||
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Apply a diff to the original content
|
* Apply a diff to the original content
|
||||||
* @param originalContent The original file content
|
* @param originalContent The original file content
|
||||||
* @param diffContent The diff content in the strategy's format
|
* @param diffContent The diff content in the strategy's format
|
||||||
* @param startLine Optional line number where the search block starts. If not provided, searches the entire file.
|
* @param startLine Optional line number where the search block starts. If not provided, searches the entire file.
|
||||||
* @param endLine Optional line number where the search block ends. If not provided, searches the entire file.
|
* @param endLine Optional line number where the search block ends. If not provided, searches the entire file.
|
||||||
* @returns A DiffResult object containing either the successful result or error details
|
* @returns A DiffResult object containing either the successful result or error details
|
||||||
*/
|
*/
|
||||||
applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult>
|
applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult>
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
// Create mock vscode module before importing anything
|
// Create mock vscode module before importing anything
|
||||||
const createMockUri = (scheme: string, path: string) => ({
|
const createMockUri = (scheme: string, path: string) => ({
|
||||||
scheme,
|
scheme,
|
||||||
authority: '',
|
authority: "",
|
||||||
path,
|
path,
|
||||||
query: '',
|
query: "",
|
||||||
fragment: '',
|
fragment: "",
|
||||||
fsPath: path,
|
fsPath: path,
|
||||||
with: jest.fn(),
|
with: jest.fn(),
|
||||||
toString: () => path,
|
toString: () => path,
|
||||||
toJSON: () => ({
|
toJSON: () => ({
|
||||||
scheme,
|
scheme,
|
||||||
authority: '',
|
authority: "",
|
||||||
path,
|
path,
|
||||||
query: '',
|
query: "",
|
||||||
fragment: ''
|
fragment: "",
|
||||||
})
|
}),
|
||||||
})
|
})
|
||||||
|
|
||||||
const mockExecuteCommand = jest.fn()
|
const mockExecuteCommand = jest.fn()
|
||||||
@@ -23,9 +23,11 @@ const mockShowErrorMessage = jest.fn()
|
|||||||
|
|
||||||
const mockVscode = {
|
const mockVscode = {
|
||||||
workspace: {
|
workspace: {
|
||||||
workspaceFolders: [{
|
workspaceFolders: [
|
||||||
uri: { fsPath: "/test/workspace" }
|
{
|
||||||
}]
|
uri: { fsPath: "/test/workspace" },
|
||||||
|
},
|
||||||
|
],
|
||||||
},
|
},
|
||||||
window: {
|
window: {
|
||||||
showErrorMessage: mockShowErrorMessage,
|
showErrorMessage: mockShowErrorMessage,
|
||||||
@@ -34,17 +36,17 @@ const mockVscode = {
|
|||||||
createTextEditorDecorationType: jest.fn(),
|
createTextEditorDecorationType: jest.fn(),
|
||||||
createOutputChannel: jest.fn(),
|
createOutputChannel: jest.fn(),
|
||||||
createWebviewPanel: jest.fn(),
|
createWebviewPanel: jest.fn(),
|
||||||
activeTextEditor: undefined
|
activeTextEditor: undefined,
|
||||||
},
|
},
|
||||||
commands: {
|
commands: {
|
||||||
executeCommand: mockExecuteCommand
|
executeCommand: mockExecuteCommand,
|
||||||
},
|
},
|
||||||
env: {
|
env: {
|
||||||
openExternal: mockOpenExternal
|
openExternal: mockOpenExternal,
|
||||||
},
|
},
|
||||||
Uri: {
|
Uri: {
|
||||||
parse: jest.fn((url: string) => createMockUri('https', url)),
|
parse: jest.fn((url: string) => createMockUri("https", url)),
|
||||||
file: jest.fn((path: string) => createMockUri('file', path))
|
file: jest.fn((path: string) => createMockUri("file", path)),
|
||||||
},
|
},
|
||||||
Position: jest.fn(),
|
Position: jest.fn(),
|
||||||
Range: jest.fn(),
|
Range: jest.fn(),
|
||||||
@@ -54,12 +56,12 @@ const mockVscode = {
|
|||||||
Error: 0,
|
Error: 0,
|
||||||
Warning: 1,
|
Warning: 1,
|
||||||
Information: 2,
|
Information: 2,
|
||||||
Hint: 3
|
Hint: 3,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mock modules
|
// Mock modules
|
||||||
jest.mock('vscode', () => mockVscode)
|
jest.mock("vscode", () => mockVscode)
|
||||||
jest.mock("../../../services/browser/UrlContentFetcher")
|
jest.mock("../../../services/browser/UrlContentFetcher")
|
||||||
jest.mock("../../../utils/git")
|
jest.mock("../../../utils/git")
|
||||||
|
|
||||||
@@ -74,7 +76,7 @@ describe("mentions", () => {
|
|||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
|
|
||||||
// Create a mock instance with just the methods we need
|
// Create a mock instance with just the methods we need
|
||||||
mockUrlContentFetcher = {
|
mockUrlContentFetcher = {
|
||||||
launchBrowser: jest.fn().mockResolvedValue(undefined),
|
launchBrowser: jest.fn().mockResolvedValue(undefined),
|
||||||
@@ -94,14 +96,10 @@ Date: Mon Jan 5 23:50:06 2025 -0500
|
|||||||
Detailed commit message with multiple lines
|
Detailed commit message with multiple lines
|
||||||
- Fixed parsing issue
|
- Fixed parsing issue
|
||||||
- Added tests`
|
- Added tests`
|
||||||
|
|
||||||
jest.mocked(git.getCommitInfo).mockResolvedValue(commitInfo)
|
jest.mocked(git.getCommitInfo).mockResolvedValue(commitInfo)
|
||||||
|
|
||||||
const result = await parseMentions(
|
const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher)
|
||||||
`Check out this commit @${commitHash}`,
|
|
||||||
mockCwd,
|
|
||||||
mockUrlContentFetcher
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(result).toContain(`'${commitHash}' (see below for commit info)`)
|
expect(result).toContain(`'${commitHash}' (see below for commit info)`)
|
||||||
expect(result).toContain(`<git_commit hash="${commitHash}">`)
|
expect(result).toContain(`<git_commit hash="${commitHash}">`)
|
||||||
@@ -111,14 +109,10 @@ Detailed commit message with multiple lines
|
|||||||
it("should handle errors fetching git info", async () => {
|
it("should handle errors fetching git info", async () => {
|
||||||
const commitHash = "abc1234"
|
const commitHash = "abc1234"
|
||||||
const errorMessage = "Failed to get commit info"
|
const errorMessage = "Failed to get commit info"
|
||||||
|
|
||||||
jest.mocked(git.getCommitInfo).mockRejectedValue(new Error(errorMessage))
|
jest.mocked(git.getCommitInfo).mockRejectedValue(new Error(errorMessage))
|
||||||
|
|
||||||
const result = await parseMentions(
|
const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher)
|
||||||
`Check out this commit @${commitHash}`,
|
|
||||||
mockCwd,
|
|
||||||
mockUrlContentFetcher
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(result).toContain(`'${commitHash}' (see below for commit info)`)
|
expect(result).toContain(`'${commitHash}' (see below for commit info)`)
|
||||||
expect(result).toContain(`<git_commit hash="${commitHash}">`)
|
expect(result).toContain(`<git_commit hash="${commitHash}">`)
|
||||||
@@ -143,13 +137,15 @@ Detailed commit message with multiple lines
|
|||||||
const mockUri = mockVscode.Uri.parse(url)
|
const mockUri = mockVscode.Uri.parse(url)
|
||||||
expect(mockOpenExternal).toHaveBeenCalled()
|
expect(mockOpenExternal).toHaveBeenCalled()
|
||||||
const calledArg = mockOpenExternal.mock.calls[0][0]
|
const calledArg = mockOpenExternal.mock.calls[0][0]
|
||||||
expect(calledArg).toEqual(expect.objectContaining({
|
expect(calledArg).toEqual(
|
||||||
scheme: mockUri.scheme,
|
expect.objectContaining({
|
||||||
authority: mockUri.authority,
|
scheme: mockUri.scheme,
|
||||||
path: mockUri.path,
|
authority: mockUri.authority,
|
||||||
query: mockUri.query,
|
path: mockUri.path,
|
||||||
fragment: mockUri.fragment
|
query: mockUri.query,
|
||||||
}))
|
fragment: mockUri.fragment,
|
||||||
|
}),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from '../shared/modes';
|
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from "../shared/modes"
|
||||||
|
|
||||||
export { isToolAllowedForMode };
|
export { isToolAllowedForMode }
|
||||||
export type { TestToolName };
|
export type { TestToolName }
|
||||||
|
|
||||||
export function validateToolUse(toolName: TestToolName, mode: Mode): void {
|
export function validateToolUse(toolName: TestToolName, mode: Mode): void {
|
||||||
if (!isToolAllowedForMode(toolName, mode)) {
|
if (!isToolAllowedForMode(toolName, mode)) {
|
||||||
throw new Error(
|
throw new Error(`Tool "${toolName}" is not allowed in ${mode} mode.`)
|
||||||
`Tool "${toolName}" is not allowed in ${mode} mode.`
|
}
|
||||||
);
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,422 +1,357 @@
|
|||||||
import { SYSTEM_PROMPT, addCustomInstructions } from '../system'
|
import { SYSTEM_PROMPT, addCustomInstructions } from "../system"
|
||||||
import { McpHub } from '../../../services/mcp/McpHub'
|
import { McpHub } from "../../../services/mcp/McpHub"
|
||||||
import { McpServer } from '../../../shared/mcp'
|
import { McpServer } from "../../../shared/mcp"
|
||||||
import { ClineProvider } from '../../../core/webview/ClineProvider'
|
import { ClineProvider } from "../../../core/webview/ClineProvider"
|
||||||
import { SearchReplaceDiffStrategy } from '../../../core/diff/strategies/search-replace'
|
import { SearchReplaceDiffStrategy } from "../../../core/diff/strategies/search-replace"
|
||||||
import fs from 'fs/promises'
|
import fs from "fs/promises"
|
||||||
import os from 'os'
|
import os from "os"
|
||||||
import { defaultModeSlug, modes } from '../../../shared/modes'
|
import { defaultModeSlug, modes } from "../../../shared/modes"
|
||||||
// Import path utils to get access to toPosix string extension
|
// Import path utils to get access to toPosix string extension
|
||||||
import '../../../utils/path'
|
import "../../../utils/path"
|
||||||
|
|
||||||
// Mock environment-specific values for consistent tests
|
// Mock environment-specific values for consistent tests
|
||||||
jest.mock('os', () => ({
|
jest.mock("os", () => ({
|
||||||
...jest.requireActual('os'),
|
...jest.requireActual("os"),
|
||||||
homedir: () => '/home/user'
|
homedir: () => "/home/user",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
jest.mock('default-shell', () => '/bin/bash')
|
jest.mock("default-shell", () => "/bin/bash")
|
||||||
|
|
||||||
jest.mock('os-name', () => () => 'Linux')
|
jest.mock("os-name", () => () => "Linux")
|
||||||
|
|
||||||
// Mock fs.readFile to return empty mcpServers config and mock rules files
|
// Mock fs.readFile to return empty mcpServers config and mock rules files
|
||||||
jest.mock('fs/promises', () => ({
|
jest.mock("fs/promises", () => ({
|
||||||
...jest.requireActual('fs/promises'),
|
...jest.requireActual("fs/promises"),
|
||||||
readFile: jest.fn().mockImplementation(async (path: string) => {
|
readFile: jest.fn().mockImplementation(async (path: string) => {
|
||||||
if (path.endsWith('mcpSettings.json')) {
|
if (path.endsWith("mcpSettings.json")) {
|
||||||
return '{"mcpServers": {}}'
|
return '{"mcpServers": {}}'
|
||||||
}
|
}
|
||||||
if (path.endsWith('.clinerules-code')) {
|
if (path.endsWith(".clinerules-code")) {
|
||||||
return '# Code Mode Rules\n1. Code specific rule'
|
return "# Code Mode Rules\n1. Code specific rule"
|
||||||
}
|
}
|
||||||
if (path.endsWith('.clinerules-ask')) {
|
if (path.endsWith(".clinerules-ask")) {
|
||||||
return '# Ask Mode Rules\n1. Ask specific rule'
|
return "# Ask Mode Rules\n1. Ask specific rule"
|
||||||
}
|
}
|
||||||
if (path.endsWith('.clinerules-architect')) {
|
if (path.endsWith(".clinerules-architect")) {
|
||||||
return '# Architect Mode Rules\n1. Architect specific rule'
|
return "# Architect Mode Rules\n1. Architect specific rule"
|
||||||
}
|
}
|
||||||
if (path.endsWith('.clinerules')) {
|
if (path.endsWith(".clinerules")) {
|
||||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||||
}
|
}
|
||||||
return ''
|
return ""
|
||||||
}),
|
}),
|
||||||
writeFile: jest.fn().mockResolvedValue(undefined)
|
writeFile: jest.fn().mockResolvedValue(undefined),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Create a minimal mock of ClineProvider
|
// Create a minimal mock of ClineProvider
|
||||||
const mockProvider = {
|
const mockProvider = {
|
||||||
ensureMcpServersDirectoryExists: async () => '/mock/mcp/path',
|
ensureMcpServersDirectoryExists: async () => "/mock/mcp/path",
|
||||||
ensureSettingsDirectoryExists: async () => '/mock/settings/path',
|
ensureSettingsDirectoryExists: async () => "/mock/settings/path",
|
||||||
postMessageToWebview: async () => {},
|
postMessageToWebview: async () => {},
|
||||||
context: {
|
context: {
|
||||||
extension: {
|
extension: {
|
||||||
packageJSON: {
|
packageJSON: {
|
||||||
version: '1.0.0'
|
version: "1.0.0",
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
} as unknown as ClineProvider
|
} as unknown as ClineProvider
|
||||||
|
|
||||||
// Instead of extending McpHub, create a mock that implements just what we need
|
// Instead of extending McpHub, create a mock that implements just what we need
|
||||||
const createMockMcpHub = (): McpHub => ({
|
const createMockMcpHub = (): McpHub =>
|
||||||
getServers: () => [],
|
({
|
||||||
getMcpServersPath: async () => '/mock/mcp/path',
|
getServers: () => [],
|
||||||
getMcpSettingsFilePath: async () => '/mock/settings/path',
|
getMcpServersPath: async () => "/mock/mcp/path",
|
||||||
dispose: async () => {},
|
getMcpSettingsFilePath: async () => "/mock/settings/path",
|
||||||
// Add other required public methods with no-op implementations
|
dispose: async () => {},
|
||||||
restartConnection: async () => {},
|
// Add other required public methods with no-op implementations
|
||||||
readResource: async () => ({ contents: [] }),
|
restartConnection: async () => {},
|
||||||
callTool: async () => ({ content: [] }),
|
readResource: async () => ({ contents: [] }),
|
||||||
toggleServerDisabled: async () => {},
|
callTool: async () => ({ content: [] }),
|
||||||
toggleToolAlwaysAllow: async () => {},
|
toggleServerDisabled: async () => {},
|
||||||
isConnecting: false,
|
toggleToolAlwaysAllow: async () => {},
|
||||||
connections: []
|
isConnecting: false,
|
||||||
} as unknown as McpHub)
|
connections: [],
|
||||||
|
}) as unknown as McpHub
|
||||||
|
|
||||||
describe('SYSTEM_PROMPT', () => {
|
describe("SYSTEM_PROMPT", () => {
|
||||||
let mockMcpHub: McpHub
|
let mockMcpHub: McpHub
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
afterEach(async () => {
|
afterEach(async () => {
|
||||||
// Clean up any McpHub instances
|
// Clean up any McpHub instances
|
||||||
if (mockMcpHub) {
|
if (mockMcpHub) {
|
||||||
await mockMcpHub.dispose()
|
await mockMcpHub.dispose()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should maintain consistent system prompt', async () => {
|
it("should maintain consistent system prompt", async () => {
|
||||||
const prompt = await SYSTEM_PROMPT(
|
const prompt = await SYSTEM_PROMPT(
|
||||||
'/test/path',
|
"/test/path",
|
||||||
false, // supportsComputerUse
|
false, // supportsComputerUse
|
||||||
undefined, // mcpHub
|
undefined, // mcpHub
|
||||||
undefined, // diffStrategy
|
undefined, // diffStrategy
|
||||||
undefined // browserViewportSize
|
undefined, // browserViewportSize
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should include browser actions when supportsComputerUse is true', async () => {
|
expect(prompt).toMatchSnapshot()
|
||||||
const prompt = await SYSTEM_PROMPT(
|
})
|
||||||
'/test/path',
|
|
||||||
true,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
'1280x800'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should include MCP server info when mcpHub is provided', async () => {
|
it("should include browser actions when supportsComputerUse is true", async () => {
|
||||||
mockMcpHub = createMockMcpHub()
|
const prompt = await SYSTEM_PROMPT("/test/path", true, undefined, undefined, "1280x800")
|
||||||
|
|
||||||
const prompt = await SYSTEM_PROMPT(
|
expect(prompt).toMatchSnapshot()
|
||||||
'/test/path',
|
})
|
||||||
false,
|
|
||||||
mockMcpHub
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should explicitly handle undefined mcpHub', async () => {
|
it("should include MCP server info when mcpHub is provided", async () => {
|
||||||
const prompt = await SYSTEM_PROMPT(
|
mockMcpHub = createMockMcpHub()
|
||||||
'/test/path',
|
|
||||||
false,
|
|
||||||
undefined, // explicitly undefined mcpHub
|
|
||||||
undefined,
|
|
||||||
undefined
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should handle different browser viewport sizes', async () => {
|
const prompt = await SYSTEM_PROMPT("/test/path", false, mockMcpHub)
|
||||||
const prompt = await SYSTEM_PROMPT(
|
|
||||||
'/test/path',
|
|
||||||
true,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
'900x600' // different viewport size
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should include diff strategy tool description', async () => {
|
expect(prompt).toMatchSnapshot()
|
||||||
const prompt = await SYSTEM_PROMPT(
|
})
|
||||||
'/test/path',
|
|
||||||
false,
|
|
||||||
undefined,
|
|
||||||
new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase
|
|
||||||
undefined
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
afterAll(() => {
|
it("should explicitly handle undefined mcpHub", async () => {
|
||||||
jest.restoreAllMocks()
|
const prompt = await SYSTEM_PROMPT(
|
||||||
})
|
"/test/path",
|
||||||
|
false,
|
||||||
|
undefined, // explicitly undefined mcpHub
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(prompt).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle different browser viewport sizes", async () => {
|
||||||
|
const prompt = await SYSTEM_PROMPT(
|
||||||
|
"/test/path",
|
||||||
|
true,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
"900x600", // different viewport size
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(prompt).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should include diff strategy tool description", async () => {
|
||||||
|
const prompt = await SYSTEM_PROMPT(
|
||||||
|
"/test/path",
|
||||||
|
false,
|
||||||
|
undefined,
|
||||||
|
new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase
|
||||||
|
undefined,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(prompt).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
jest.restoreAllMocks()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('addCustomInstructions', () => {
|
describe("addCustomInstructions", () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should generate correct prompt for architect mode', async () => {
|
it("should generate correct prompt for architect mode", async () => {
|
||||||
const prompt = await SYSTEM_PROMPT(
|
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "architect")
|
||||||
'/test/path',
|
|
||||||
false,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
'architect'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should generate correct prompt for ask mode', async () => {
|
expect(prompt).toMatchSnapshot()
|
||||||
const prompt = await SYSTEM_PROMPT(
|
})
|
||||||
'/test/path',
|
|
||||||
false,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
'ask'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should prioritize mode-specific rules for code mode', async () => {
|
it("should generate correct prompt for ask mode", async () => {
|
||||||
const instructions = await addCustomInstructions(
|
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "ask")
|
||||||
{},
|
|
||||||
'/test/path',
|
|
||||||
defaultModeSlug
|
|
||||||
)
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should prioritize mode-specific rules for ask mode', async () => {
|
expect(prompt).toMatchSnapshot()
|
||||||
const instructions = await addCustomInstructions(
|
})
|
||||||
{},
|
|
||||||
'/test/path',
|
|
||||||
modes[2].slug
|
|
||||||
)
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should prioritize mode-specific rules for architect mode', async () => {
|
it("should prioritize mode-specific rules for code mode", async () => {
|
||||||
const instructions = await addCustomInstructions(
|
const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug)
|
||||||
{},
|
expect(instructions).toMatchSnapshot()
|
||||||
'/test/path',
|
})
|
||||||
modes[1].slug
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should prioritize mode-specific rules for test engineer mode', async () => {
|
it("should prioritize mode-specific rules for ask mode", async () => {
|
||||||
// Mock readFile to include test engineer rules
|
const instructions = await addCustomInstructions({}, "/test/path", modes[2].slug)
|
||||||
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
expect(instructions).toMatchSnapshot()
|
||||||
if (path.endsWith('.clinerules-test')) {
|
})
|
||||||
return '# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code'
|
|
||||||
}
|
|
||||||
if (path.endsWith('.clinerules')) {
|
|
||||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
|
||||||
}
|
|
||||||
return ''
|
|
||||||
})
|
|
||||||
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile)
|
|
||||||
|
|
||||||
const instructions = await addCustomInstructions(
|
it("should prioritize mode-specific rules for architect mode", async () => {
|
||||||
{},
|
const instructions = await addCustomInstructions({}, "/test/path", modes[1].slug)
|
||||||
'/test/path',
|
|
||||||
'test'
|
|
||||||
)
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should prioritize mode-specific rules for code reviewer mode', async () => {
|
expect(instructions).toMatchSnapshot()
|
||||||
// Mock readFile to include code reviewer rules
|
})
|
||||||
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
|
||||||
if (path.endsWith('.clinerules-review')) {
|
|
||||||
return '# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices'
|
|
||||||
}
|
|
||||||
if (path.endsWith('.clinerules')) {
|
|
||||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
|
||||||
}
|
|
||||||
return ''
|
|
||||||
})
|
|
||||||
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile)
|
|
||||||
|
|
||||||
const instructions = await addCustomInstructions(
|
it("should prioritize mode-specific rules for test engineer mode", async () => {
|
||||||
{},
|
// Mock readFile to include test engineer rules
|
||||||
'/test/path',
|
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
||||||
'review'
|
if (path.endsWith(".clinerules-test")) {
|
||||||
)
|
return "# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code"
|
||||||
expect(instructions).toMatchSnapshot()
|
}
|
||||||
})
|
if (path.endsWith(".clinerules")) {
|
||||||
|
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
})
|
||||||
|
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
|
||||||
|
|
||||||
it('should generate correct prompt for test engineer mode', async () => {
|
const instructions = await addCustomInstructions({}, "/test/path", "test")
|
||||||
const prompt = await SYSTEM_PROMPT(
|
expect(instructions).toMatchSnapshot()
|
||||||
'/test/path',
|
})
|
||||||
false,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
'test'
|
|
||||||
)
|
|
||||||
|
|
||||||
// Verify test engineer role requirements
|
|
||||||
expect(prompt).toContain('must ask the user to confirm before making ANY changes to non-test code')
|
|
||||||
expect(prompt).toContain('ask the user to confirm your test plan')
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should generate correct prompt for code reviewer mode', async () => {
|
it("should prioritize mode-specific rules for code reviewer mode", async () => {
|
||||||
const prompt = await SYSTEM_PROMPT(
|
// Mock readFile to include code reviewer rules
|
||||||
'/test/path',
|
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
||||||
false,
|
if (path.endsWith(".clinerules-review")) {
|
||||||
undefined,
|
return "# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices"
|
||||||
undefined,
|
}
|
||||||
undefined,
|
if (path.endsWith(".clinerules")) {
|
||||||
'review'
|
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||||
)
|
}
|
||||||
|
return ""
|
||||||
// Verify code reviewer role constraints
|
})
|
||||||
expect(prompt).toContain('providing detailed, actionable feedback')
|
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
|
||||||
expect(prompt).toContain('maintain a read-only approach')
|
|
||||||
expect(prompt).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should fall back to generic rules when mode-specific rules not found', async () => {
|
const instructions = await addCustomInstructions({}, "/test/path", "review")
|
||||||
// Mock readFile to return ENOENT for mode-specific file
|
expect(instructions).toMatchSnapshot()
|
||||||
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
})
|
||||||
if (path.endsWith('.clinerules-code') ||
|
|
||||||
path.endsWith('.clinerules-test') ||
|
|
||||||
path.endsWith('.clinerules-review')) {
|
|
||||||
const error = new Error('ENOENT') as NodeJS.ErrnoException
|
|
||||||
error.code = 'ENOENT'
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
if (path.endsWith('.clinerules')) {
|
|
||||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
|
||||||
}
|
|
||||||
return ''
|
|
||||||
})
|
|
||||||
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile)
|
|
||||||
|
|
||||||
const instructions = await addCustomInstructions(
|
it("should generate correct prompt for test engineer mode", async () => {
|
||||||
{},
|
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "test")
|
||||||
'/test/path',
|
|
||||||
defaultModeSlug
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should include preferred language when provided', async () => {
|
// Verify test engineer role requirements
|
||||||
const instructions = await addCustomInstructions(
|
expect(prompt).toContain("must ask the user to confirm before making ANY changes to non-test code")
|
||||||
{ preferredLanguage: 'Spanish' },
|
expect(prompt).toContain("ask the user to confirm your test plan")
|
||||||
'/test/path',
|
expect(prompt).toMatchSnapshot()
|
||||||
defaultModeSlug
|
})
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should include custom instructions when provided', async () => {
|
it("should generate correct prompt for code reviewer mode", async () => {
|
||||||
const instructions = await addCustomInstructions(
|
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "review")
|
||||||
{ customInstructions: 'Custom test instructions' },
|
|
||||||
'/test/path'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should combine all custom instructions', async () => {
|
// Verify code reviewer role constraints
|
||||||
const instructions = await addCustomInstructions(
|
expect(prompt).toContain("providing detailed, actionable feedback")
|
||||||
{
|
expect(prompt).toContain("maintain a read-only approach")
|
||||||
customInstructions: 'Custom test instructions',
|
expect(prompt).toMatchSnapshot()
|
||||||
preferredLanguage: 'French'
|
})
|
||||||
},
|
|
||||||
'/test/path',
|
|
||||||
defaultModeSlug
|
|
||||||
)
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should handle undefined mode-specific instructions', async () => {
|
it("should fall back to generic rules when mode-specific rules not found", async () => {
|
||||||
const instructions = await addCustomInstructions(
|
// Mock readFile to return ENOENT for mode-specific file
|
||||||
{},
|
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
||||||
'/test/path'
|
if (
|
||||||
)
|
path.endsWith(".clinerules-code") ||
|
||||||
|
path.endsWith(".clinerules-test") ||
|
||||||
expect(instructions).toMatchSnapshot()
|
path.endsWith(".clinerules-review")
|
||||||
})
|
) {
|
||||||
|
const error = new Error("ENOENT") as NodeJS.ErrnoException
|
||||||
|
error.code = "ENOENT"
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
if (path.endsWith(".clinerules")) {
|
||||||
|
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
})
|
||||||
|
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
|
||||||
|
|
||||||
it('should trim mode-specific instructions', async () => {
|
const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug)
|
||||||
const instructions = await addCustomInstructions(
|
|
||||||
{ customInstructions: ' Custom mode instructions ' },
|
|
||||||
'/test/path'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should handle empty mode-specific instructions', async () => {
|
expect(instructions).toMatchSnapshot()
|
||||||
const instructions = await addCustomInstructions(
|
})
|
||||||
{ customInstructions: '' },
|
|
||||||
'/test/path'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should combine global and mode-specific instructions', async () => {
|
it("should include preferred language when provided", async () => {
|
||||||
const instructions = await addCustomInstructions(
|
const instructions = await addCustomInstructions(
|
||||||
{
|
{ preferredLanguage: "Spanish" },
|
||||||
customInstructions: 'Global instructions',
|
"/test/path",
|
||||||
customPrompts: {
|
defaultModeSlug,
|
||||||
code: { customInstructions: 'Mode-specific instructions' }
|
)
|
||||||
}
|
|
||||||
},
|
|
||||||
'/test/path',
|
|
||||||
defaultModeSlug
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should prioritize mode-specific instructions after global ones', async () => {
|
expect(instructions).toMatchSnapshot()
|
||||||
const instructions = await addCustomInstructions(
|
})
|
||||||
{
|
|
||||||
customInstructions: 'First instruction',
|
|
||||||
customPrompts: {
|
|
||||||
code: { customInstructions: 'Second instruction' }
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'/test/path',
|
|
||||||
defaultModeSlug
|
|
||||||
)
|
|
||||||
|
|
||||||
const instructionParts = instructions.split('\n\n')
|
|
||||||
const globalIndex = instructionParts.findIndex(part => part.includes('First instruction'))
|
|
||||||
const modeSpecificIndex = instructionParts.findIndex(part => part.includes('Second instruction'))
|
|
||||||
|
|
||||||
expect(globalIndex).toBeLessThan(modeSpecificIndex)
|
|
||||||
expect(instructions).toMatchSnapshot()
|
|
||||||
})
|
|
||||||
|
|
||||||
afterAll(() => {
|
it("should include custom instructions when provided", async () => {
|
||||||
jest.restoreAllMocks()
|
const instructions = await addCustomInstructions(
|
||||||
})
|
{ customInstructions: "Custom test instructions" },
|
||||||
|
"/test/path",
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should combine all custom instructions", async () => {
|
||||||
|
const instructions = await addCustomInstructions(
|
||||||
|
{
|
||||||
|
customInstructions: "Custom test instructions",
|
||||||
|
preferredLanguage: "French",
|
||||||
|
},
|
||||||
|
"/test/path",
|
||||||
|
defaultModeSlug,
|
||||||
|
)
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle undefined mode-specific instructions", async () => {
|
||||||
|
const instructions = await addCustomInstructions({}, "/test/path")
|
||||||
|
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should trim mode-specific instructions", async () => {
|
||||||
|
const instructions = await addCustomInstructions(
|
||||||
|
{ customInstructions: " Custom mode instructions " },
|
||||||
|
"/test/path",
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle empty mode-specific instructions", async () => {
|
||||||
|
const instructions = await addCustomInstructions({ customInstructions: "" }, "/test/path")
|
||||||
|
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should combine global and mode-specific instructions", async () => {
|
||||||
|
const instructions = await addCustomInstructions(
|
||||||
|
{
|
||||||
|
customInstructions: "Global instructions",
|
||||||
|
customPrompts: {
|
||||||
|
code: { customInstructions: "Mode-specific instructions" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/test/path",
|
||||||
|
defaultModeSlug,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should prioritize mode-specific instructions after global ones", async () => {
|
||||||
|
const instructions = await addCustomInstructions(
|
||||||
|
{
|
||||||
|
customInstructions: "First instruction",
|
||||||
|
customPrompts: {
|
||||||
|
code: { customInstructions: "Second instruction" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/test/path",
|
||||||
|
defaultModeSlug,
|
||||||
|
)
|
||||||
|
|
||||||
|
const instructionParts = instructions.split("\n\n")
|
||||||
|
const globalIndex = instructionParts.findIndex((part) => part.includes("First instruction"))
|
||||||
|
const modeSpecificIndex = instructionParts.findIndex((part) => part.includes("Second instruction"))
|
||||||
|
|
||||||
|
expect(globalIndex).toBeLessThan(modeSpecificIndex)
|
||||||
|
expect(instructions).toMatchSnapshot()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
jest.restoreAllMocks()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,27 +2,31 @@ import { DiffStrategy } from "../../diff/DiffStrategy"
|
|||||||
import { McpHub } from "../../../services/mcp/McpHub"
|
import { McpHub } from "../../../services/mcp/McpHub"
|
||||||
|
|
||||||
export function getCapabilitiesSection(
|
export function getCapabilitiesSection(
|
||||||
cwd: string,
|
cwd: string,
|
||||||
supportsComputerUse: boolean,
|
supportsComputerUse: boolean,
|
||||||
mcpHub?: McpHub,
|
mcpHub?: McpHub,
|
||||||
diffStrategy?: DiffStrategy,
|
diffStrategy?: DiffStrategy,
|
||||||
): string {
|
): string {
|
||||||
return `====
|
return `====
|
||||||
|
|
||||||
CAPABILITIES
|
CAPABILITIES
|
||||||
|
|
||||||
- You have access to tools that let you execute CLI commands on the user's computer, list files, view source code definitions, regex search${
|
- You have access to tools that let you execute CLI commands on the user's computer, list files, view source code definitions, regex search${
|
||||||
supportsComputerUse ? ", use the browser" : ""
|
supportsComputerUse ? ", use the browser" : ""
|
||||||
}, read and write files, and ask follow-up questions. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, performing system operations, and much more.
|
}, read and write files, and ask follow-up questions. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, performing system operations, and much more.
|
||||||
- When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('${cwd}') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop.
|
- When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('${cwd}') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop.
|
||||||
- You can use search_files to perform regex searches across files in a specified directory, outputting context-rich results that include surrounding lines. This is particularly useful for understanding code patterns, finding specific implementations, or identifying areas that need refactoring.
|
- You can use search_files to perform regex searches across files in a specified directory, outputting context-rich results that include surrounding lines. This is particularly useful for understanding code patterns, finding specific implementations, or identifying areas that need refactoring.
|
||||||
- You can use the list_code_definition_names tool to get an overview of source code definitions for all files at the top level of a specified directory. This can be particularly useful when you need to understand the broader context and relationships between certain parts of the code. You may need to call this tool multiple times to understand various parts of the codebase related to the task.
|
- You can use the list_code_definition_names tool to get an overview of source code definitions for all files at the top level of a specified directory. This can be particularly useful when you need to understand the broader context and relationships between certain parts of the code. You may need to call this tool multiple times to understand various parts of the codebase related to the task.
|
||||||
- For example, when asked to make edits or improvements you might analyze the file structure in the initial environment_details to get an overview of the project, then use list_code_definition_names to get further insight using source code definitions for files located in relevant directories, then read_file to examine the contents of relevant files, analyze the code and suggest improvements or make necessary edits, then use the write_to_file ${diffStrategy ? "or apply_diff " : ""}tool to apply the changes. If you refactored code that could affect other parts of the codebase, you could use search_files to ensure you update other files as needed.
|
- For example, when asked to make edits or improvements you might analyze the file structure in the initial environment_details to get an overview of the project, then use list_code_definition_names to get further insight using source code definitions for files located in relevant directories, then read_file to examine the contents of relevant files, analyze the code and suggest improvements or make necessary edits, then use the write_to_file ${diffStrategy ? "or apply_diff " : ""}tool to apply the changes. If you refactored code that could affect other parts of the codebase, you could use search_files to ensure you update other files as needed.
|
||||||
- You can use the execute_command tool to run commands on the user's computer whenever you feel it can help accomplish the user's task. When you need to execute a CLI command, you must provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, since they are more flexible and easier to run. Interactive and long-running commands are allowed, since the commands are run in the user's VSCode terminal. The user may keep commands running in the background and you will be kept updated on their status along the way. Each command you execute is run in a new terminal instance.${
|
- You can use the execute_command tool to run commands on the user's computer whenever you feel it can help accomplish the user's task. When you need to execute a CLI command, you must provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, since they are more flexible and easier to run. Interactive and long-running commands are allowed, since the commands are run in the user's VSCode terminal. The user may keep commands running in the background and you will be kept updated on their status along the way. Each command you execute is run in a new terminal instance.${
|
||||||
supportsComputerUse
|
supportsComputerUse
|
||||||
? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser."
|
? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser."
|
||||||
: ""
|
: ""
|
||||||
}${mcpHub ? `
|
}${
|
||||||
|
mcpHub
|
||||||
|
? `
|
||||||
- You have access to MCP servers that may provide additional tools and resources. Each server may provide different capabilities that you can use to accomplish tasks more effectively.
|
- You have access to MCP servers that may provide additional tools and resources. Each server may provide different capabilities that you can use to accomplish tasks more effectively.
|
||||||
` : ''}`
|
`
|
||||||
}
|
: ""
|
||||||
|
}`
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,46 +1,51 @@
|
|||||||
import fs from 'fs/promises'
|
import fs from "fs/promises"
|
||||||
import path from 'path'
|
import path from "path"
|
||||||
|
|
||||||
export async function loadRuleFiles(cwd: string): Promise<string> {
|
export async function loadRuleFiles(cwd: string): Promise<string> {
|
||||||
const ruleFiles = ['.clinerules', '.cursorrules', '.windsurfrules']
|
const ruleFiles = [".clinerules", ".cursorrules", ".windsurfrules"]
|
||||||
let combinedRules = ''
|
let combinedRules = ""
|
||||||
|
|
||||||
for (const file of ruleFiles) {
|
for (const file of ruleFiles) {
|
||||||
try {
|
try {
|
||||||
const content = await fs.readFile(path.join(cwd, file), 'utf-8')
|
const content = await fs.readFile(path.join(cwd, file), "utf-8")
|
||||||
if (content.trim()) {
|
if (content.trim()) {
|
||||||
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
|
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
// Silently skip if file doesn't exist
|
// Silently skip if file doesn't exist
|
||||||
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') {
|
if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return combinedRules
|
return combinedRules
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function addCustomInstructions(customInstructions: string, cwd: string, preferredLanguage?: string): Promise<string> {
|
export async function addCustomInstructions(
|
||||||
const ruleFileContent = await loadRuleFiles(cwd)
|
customInstructions: string,
|
||||||
const allInstructions = []
|
cwd: string,
|
||||||
|
preferredLanguage?: string,
|
||||||
|
): Promise<string> {
|
||||||
|
const ruleFileContent = await loadRuleFiles(cwd)
|
||||||
|
const allInstructions = []
|
||||||
|
|
||||||
if (preferredLanguage) {
|
if (preferredLanguage) {
|
||||||
allInstructions.push(`You should always speak and think in the ${preferredLanguage} language.`)
|
allInstructions.push(`You should always speak and think in the ${preferredLanguage} language.`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (customInstructions.trim()) {
|
|
||||||
allInstructions.push(customInstructions.trim())
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ruleFileContent && ruleFileContent.trim()) {
|
if (customInstructions.trim()) {
|
||||||
allInstructions.push(ruleFileContent.trim())
|
allInstructions.push(customInstructions.trim())
|
||||||
}
|
}
|
||||||
|
|
||||||
const joinedInstructions = allInstructions.join('\n\n')
|
if (ruleFileContent && ruleFileContent.trim()) {
|
||||||
|
allInstructions.push(ruleFileContent.trim())
|
||||||
|
}
|
||||||
|
|
||||||
return joinedInstructions ? `
|
const joinedInstructions = allInstructions.join("\n\n")
|
||||||
|
|
||||||
|
return joinedInstructions
|
||||||
|
? `
|
||||||
====
|
====
|
||||||
|
|
||||||
USER'S CUSTOM INSTRUCTIONS
|
USER'S CUSTOM INSTRUCTIONS
|
||||||
@@ -48,5 +53,5 @@ USER'S CUSTOM INSTRUCTIONS
|
|||||||
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the TOOL USE guidelines.
|
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the TOOL USE guidelines.
|
||||||
|
|
||||||
${joinedInstructions}`
|
${joinedInstructions}`
|
||||||
: ""
|
: ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
export { getRulesSection } from './rules'
|
export { getRulesSection } from "./rules"
|
||||||
export { getSystemInfoSection } from './system-info'
|
export { getSystemInfoSection } from "./system-info"
|
||||||
export { getObjectiveSection } from './objective'
|
export { getObjectiveSection } from "./objective"
|
||||||
export { addCustomInstructions } from './custom-instructions'
|
export { addCustomInstructions } from "./custom-instructions"
|
||||||
export { getSharedToolUseSection } from './tool-use'
|
export { getSharedToolUseSection } from "./tool-use"
|
||||||
export { getMcpServersSection } from './mcp-servers'
|
export { getMcpServersSection } from "./mcp-servers"
|
||||||
export { getToolUseGuidelinesSection } from './tool-use-guidelines'
|
export { getToolUseGuidelinesSection } from "./tool-use-guidelines"
|
||||||
export { getCapabilitiesSection } from './capabilities'
|
export { getCapabilitiesSection } from "./capabilities"
|
||||||
|
|||||||
@@ -2,47 +2,48 @@ import { DiffStrategy } from "../../diff/DiffStrategy"
|
|||||||
import { McpHub } from "../../../services/mcp/McpHub"
|
import { McpHub } from "../../../services/mcp/McpHub"
|
||||||
|
|
||||||
export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffStrategy): Promise<string> {
|
export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffStrategy): Promise<string> {
|
||||||
if (!mcpHub) {
|
if (!mcpHub) {
|
||||||
return '';
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
const connectedServers = mcpHub.getServers().length > 0
|
const connectedServers =
|
||||||
? `${mcpHub
|
mcpHub.getServers().length > 0
|
||||||
.getServers()
|
? `${mcpHub
|
||||||
.filter((server) => server.status === "connected")
|
.getServers()
|
||||||
.map((server) => {
|
.filter((server) => server.status === "connected")
|
||||||
const tools = server.tools
|
.map((server) => {
|
||||||
?.map((tool) => {
|
const tools = server.tools
|
||||||
const schemaStr = tool.inputSchema
|
?.map((tool) => {
|
||||||
? ` Input Schema:
|
const schemaStr = tool.inputSchema
|
||||||
|
? ` Input Schema:
|
||||||
${JSON.stringify(tool.inputSchema, null, 2).split("\n").join("\n ")}`
|
${JSON.stringify(tool.inputSchema, null, 2).split("\n").join("\n ")}`
|
||||||
: ""
|
: ""
|
||||||
|
|
||||||
return `- ${tool.name}: ${tool.description}\n${schemaStr}`
|
return `- ${tool.name}: ${tool.description}\n${schemaStr}`
|
||||||
})
|
})
|
||||||
.join("\n\n")
|
.join("\n\n")
|
||||||
|
|
||||||
const templates = server.resourceTemplates
|
const templates = server.resourceTemplates
|
||||||
?.map((template) => `- ${template.uriTemplate} (${template.name}): ${template.description}`)
|
?.map((template) => `- ${template.uriTemplate} (${template.name}): ${template.description}`)
|
||||||
.join("\n")
|
.join("\n")
|
||||||
|
|
||||||
const resources = server.resources
|
const resources = server.resources
|
||||||
?.map((resource) => `- ${resource.uri} (${resource.name}): ${resource.description}`)
|
?.map((resource) => `- ${resource.uri} (${resource.name}): ${resource.description}`)
|
||||||
.join("\n")
|
.join("\n")
|
||||||
|
|
||||||
const config = JSON.parse(server.config)
|
const config = JSON.parse(server.config)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
`## ${server.name} (\`${config.command}${config.args && Array.isArray(config.args) ? ` ${config.args.join(" ")}` : ""}\`)` +
|
`## ${server.name} (\`${config.command}${config.args && Array.isArray(config.args) ? ` ${config.args.join(" ")}` : ""}\`)` +
|
||||||
(tools ? `\n\n### Available Tools\n${tools}` : "") +
|
(tools ? `\n\n### Available Tools\n${tools}` : "") +
|
||||||
(templates ? `\n\n### Resource Templates\n${templates}` : "") +
|
(templates ? `\n\n### Resource Templates\n${templates}` : "") +
|
||||||
(resources ? `\n\n### Direct Resources\n${resources}` : "")
|
(resources ? `\n\n### Direct Resources\n${resources}` : "")
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.join("\n\n")}`
|
.join("\n\n")}`
|
||||||
: "(No MCP servers currently connected)";
|
: "(No MCP servers currently connected)"
|
||||||
|
|
||||||
return `MCP SERVERS
|
return `MCP SERVERS
|
||||||
|
|
||||||
The Model Context Protocol (MCP) enables communication between the system and locally running MCP servers that provide additional tools and resources to extend your capabilities.
|
The Model Context Protocol (MCP) enables communication between the system and locally running MCP servers that provide additional tools and resources to extend your capabilities.
|
||||||
|
|
||||||
@@ -397,11 +398,11 @@ IMPORTANT: Regardless of what else you see in the MCP settings file, you must de
|
|||||||
## Editing MCP Servers
|
## Editing MCP Servers
|
||||||
|
|
||||||
The user may ask to add tools or resources that may make sense to add to an existing MCP server (listed under 'Connected MCP Servers' above: ${
|
The user may ask to add tools or resources that may make sense to add to an existing MCP server (listed under 'Connected MCP Servers' above: ${
|
||||||
mcpHub
|
mcpHub
|
||||||
.getServers()
|
.getServers()
|
||||||
.map((server) => server.name)
|
.map((server) => server.name)
|
||||||
.join(", ") || "(None running currently)"
|
.join(", ") || "(None running currently)"
|
||||||
}, e.g. if it would use the same API. This would be possible if you can locate the MCP server repository on the user's system by looking at the server arguments for a filepath. You might then use list_files and read_file to explore the files in the repository, and use write_to_file${diffStrategy ? " or apply_diff" : ""} to make changes to the files.
|
}, e.g. if it would use the same API. This would be possible if you can locate the MCP server repository on the user's system by looking at the server arguments for a filepath. You might then use list_files and read_file to explore the files in the repository, and use write_to_file${diffStrategy ? " or apply_diff" : ""} to make changes to the files.
|
||||||
|
|
||||||
However some MCP servers may be running from installed packages rather than a local repository, in which case it may make more sense to create a new MCP server.
|
However some MCP servers may be running from installed packages rather than a local repository, in which case it may make more sense to create a new MCP server.
|
||||||
|
|
||||||
@@ -410,4 +411,4 @@ However some MCP servers may be running from installed packages rather than a lo
|
|||||||
The user may not always request the use or creation of MCP servers. Instead, they might provide tasks that can be completed with existing tools. While using the MCP SDK to extend your capabilities can be useful, it's important to understand that this is just one specialized type of task you can accomplish. You should only implement MCP servers when the user explicitly requests it (e.g., "add a tool that...").
|
The user may not always request the use or creation of MCP servers. Instead, they might provide tasks that can be completed with existing tools. While using the MCP SDK to extend your capabilities can be useful, it's important to understand that this is just one specialized type of task you can accomplish. You should only implement MCP servers when the user explicitly requests it (e.g., "add a tool that...").
|
||||||
|
|
||||||
Remember: The MCP documentation and example provided above are to help you understand and work with existing MCP servers or create new ones when requested by the user. You already have access to tools and capabilities that can be used to accomplish a wide range of tasks.`
|
Remember: The MCP documentation and example provided above are to help you understand and work with existing MCP servers or create new ones when requested by the user. You already have access to tools and capabilities that can be used to accomplish a wide range of tasks.`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
export function getObjectiveSection(): string {
|
export function getObjectiveSection(): string {
|
||||||
return `====
|
return `====
|
||||||
|
|
||||||
OBJECTIVE
|
OBJECTIVE
|
||||||
|
|
||||||
@@ -10,4 +10,4 @@ You accomplish a given task iteratively, breaking it down into clear steps and w
|
|||||||
3. Remember, you have extensive capabilities with access to a wide range of tools that can be used in powerful and clever ways as necessary to accomplish each goal. Before calling a tool, do some analysis within <thinking></thinking> tags. First, analyze the file structure provided in environment_details to gain context and insights for proceeding effectively. Then, think about which of the provided tools is the most relevant tool to accomplish the user's task. Next, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool use. BUT, if one of the values for a required parameter is missing, DO NOT invoke the tool (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters using the ask_followup_question tool. DO NOT ask for more information on optional parameters if it is not provided.
|
3. Remember, you have extensive capabilities with access to a wide range of tools that can be used in powerful and clever ways as necessary to accomplish each goal. Before calling a tool, do some analysis within <thinking></thinking> tags. First, analyze the file structure provided in environment_details to gain context and insights for proceeding effectively. Then, think about which of the provided tools is the most relevant tool to accomplish the user's task. Next, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool use. BUT, if one of the values for a required parameter is missing, DO NOT invoke the tool (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters using the ask_followup_question tool. DO NOT ask for more information on optional parameters if it is not provided.
|
||||||
4. Once you've completed the user's task, you must use the attempt_completion tool to present the result of the task to the user. You may also provide a CLI command to showcase the result of your task; this can be particularly useful for web development tasks, where you can run e.g. \`open index.html\` to show the website you've built.
|
4. Once you've completed the user's task, you must use the attempt_completion tool to present the result of the task to the user. You may also provide a CLI command to showcase the result of your task; this can be particularly useful for web development tasks, where you can run e.g. \`open index.html\` to show the website you've built.
|
||||||
5. The user may provide feedback, which you can use to make improvements and try again. But DO NOT continue in pointless back and forth conversations, i.e. don't end your responses with questions or offers for further assistance.`
|
5. The user may provide feedback, which you can use to make improvements and try again. But DO NOT continue in pointless back and forth conversations, i.e. don't end your responses with questions or offers for further assistance.`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
import { DiffStrategy } from "../../diff/DiffStrategy"
|
import { DiffStrategy } from "../../diff/DiffStrategy"
|
||||||
|
|
||||||
export function getRulesSection(
|
export function getRulesSection(cwd: string, supportsComputerUse: boolean, diffStrategy?: DiffStrategy): string {
|
||||||
cwd: string,
|
return `====
|
||||||
supportsComputerUse: boolean,
|
|
||||||
diffStrategy?: DiffStrategy
|
|
||||||
): string {
|
|
||||||
return `====
|
|
||||||
|
|
||||||
RULES
|
RULES
|
||||||
|
|
||||||
@@ -23,10 +19,10 @@ ${diffStrategy ? "- You should use apply_diff instead of write_to_file when maki
|
|||||||
- When executing commands, if you don't see the expected output, assume the terminal executed the command successfully and proceed with the task. The user's terminal may be unable to stream the output back properly. If you absolutely need to see the actual terminal output, use the ask_followup_question tool to request the user to copy and paste it back to you.
|
- When executing commands, if you don't see the expected output, assume the terminal executed the command successfully and proceed with the task. The user's terminal may be unable to stream the output back properly. If you absolutely need to see the actual terminal output, use the ask_followup_question tool to request the user to copy and paste it back to you.
|
||||||
- The user may provide a file's contents directly in their message, in which case you shouldn't use the read_file tool to get the file contents again since you already have it.
|
- The user may provide a file's contents directly in their message, in which case you shouldn't use the read_file tool to get the file contents again since you already have it.
|
||||||
- Your goal is to try to accomplish the user's task, NOT engage in a back and forth conversation.${
|
- Your goal is to try to accomplish the user's task, NOT engage in a back and forth conversation.${
|
||||||
supportsComputerUse
|
supportsComputerUse
|
||||||
? '\n- The user may ask generic non-development tasks, such as "what\'s the latest news" or "look up the weather in San Diego", in which case you might use the browser_action tool to complete the task if it makes sense to do so, rather than trying to create a website or using curl to answer the question. However, if an available MCP server tool or resource can be used instead, you should prefer to use it over browser_action.'
|
? '\n- The user may ask generic non-development tasks, such as "what\'s the latest news" or "look up the weather in San Diego", in which case you might use the browser_action tool to complete the task if it makes sense to do so, rather than trying to create a website or using curl to answer the question. However, if an available MCP server tool or resource can be used instead, you should prefer to use it over browser_action.'
|
||||||
: ""
|
: ""
|
||||||
}
|
}
|
||||||
- NEVER end attempt_completion result with a question or request to engage in further conversation! Formulate the end of your result in a way that is final and does not require further input from the user.
|
- NEVER end attempt_completion result with a question or request to engage in further conversation! Formulate the end of your result in a way that is final and does not require further input from the user.
|
||||||
- You are STRICTLY FORBIDDEN from starting your messages with "Great", "Certainly", "Okay", "Sure". You should NOT be conversational in your responses, but rather direct and to the point. For example you should NOT say "Great, I've updated the CSS" but instead something like "I've updated the CSS". It is important you be clear and technical in your messages.
|
- You are STRICTLY FORBIDDEN from starting your messages with "Great", "Certainly", "Okay", "Sure". You should NOT be conversational in your responses, but rather direct and to the point. For example you should NOT say "Great, I've updated the CSS" but instead something like "I've updated the CSS". It is important you be clear and technical in your messages.
|
||||||
- When presented with images, utilize your vision capabilities to thoroughly examine them and extract meaningful information. Incorporate these insights into your thought process as you accomplish the user's task.
|
- When presented with images, utilize your vision capabilities to thoroughly examine them and extract meaningful information. Incorporate these insights into your thought process as you accomplish the user's task.
|
||||||
@@ -35,8 +31,8 @@ ${diffStrategy ? "- You should use apply_diff instead of write_to_file when maki
|
|||||||
- When using the write_to_file tool, ALWAYS provide the COMPLETE file content in your response. This is NON-NEGOTIABLE. Partial updates or placeholders like '// rest of code unchanged' are STRICTLY FORBIDDEN. You MUST include ALL parts of the file, even if they haven't been modified. Failure to do so will result in incomplete or broken code, severely impacting the user's project.
|
- When using the write_to_file tool, ALWAYS provide the COMPLETE file content in your response. This is NON-NEGOTIABLE. Partial updates or placeholders like '// rest of code unchanged' are STRICTLY FORBIDDEN. You MUST include ALL parts of the file, even if they haven't been modified. Failure to do so will result in incomplete or broken code, severely impacting the user's project.
|
||||||
- MCP operations should be used one at a time, similar to other tool usage. Wait for confirmation of success before proceeding with additional operations.
|
- MCP operations should be used one at a time, similar to other tool usage. Wait for confirmation of success before proceeding with additional operations.
|
||||||
- It is critical you wait for the user's response after each tool use, in order to confirm the success of the tool use. For example, if asked to make a todo app, you would create a file, wait for the user's response it was created successfully, then create another file if needed, wait for the user's response it was created successfully, etc.${
|
- It is critical you wait for the user's response after each tool use, in order to confirm the success of the tool use. For example, if asked to make a todo app, you would create a file, wait for the user's response it was created successfully, then create another file if needed, wait for the user's response it was created successfully, etc.${
|
||||||
supportsComputerUse
|
supportsComputerUse
|
||||||
? " Then if you want to test your work, you might use browser_action to launch the site, wait for the user's response confirming the site was launched along with a screenshot, then perhaps e.g., click a button to test functionality if needed, wait for the user's response confirming the button was clicked along with a screenshot of the new state, before finally closing the browser."
|
? " Then if you want to test your work, you might use browser_action to launch the site, wait for the user's response confirming the site was launched along with a screenshot, then perhaps e.g., click a button to test functionality if needed, wait for the user's response confirming the button was clicked along with a screenshot of the new state, before finally closing the browser."
|
||||||
: ""
|
: ""
|
||||||
}`
|
}`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os from "os"
|
|||||||
import osName from "os-name"
|
import osName from "os-name"
|
||||||
|
|
||||||
export function getSystemInfoSection(cwd: string): string {
|
export function getSystemInfoSection(cwd: string): string {
|
||||||
return `====
|
return `====
|
||||||
|
|
||||||
SYSTEM INFORMATION
|
SYSTEM INFORMATION
|
||||||
|
|
||||||
@@ -13,4 +13,4 @@ Home Directory: ${os.homedir().toPosix()}
|
|||||||
Current Working Directory: ${cwd.toPosix()}
|
Current Working Directory: ${cwd.toPosix()}
|
||||||
|
|
||||||
When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('/test/path') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop.`
|
When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('/test/path') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop.`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
export function getToolUseGuidelinesSection(): string {
|
export function getToolUseGuidelinesSection(): string {
|
||||||
return `# Tool Use Guidelines
|
return `# Tool Use Guidelines
|
||||||
|
|
||||||
1. In <thinking> tags, assess what information you already have and what information you need to proceed with the task.
|
1. In <thinking> tags, assess what information you already have and what information you need to proceed with the task.
|
||||||
2. Choose the most appropriate tool based on the task and the tool descriptions provided. Assess if you need additional information to proceed, and which of the available tools would be most effective for gathering this information. For example using the list_files tool is more effective than running a command like \`ls\` in the terminal. It's critical that you think about each available tool and use the one that best fits the current step in the task.
|
2. Choose the most appropriate tool based on the task and the tool descriptions provided. Assess if you need additional information to proceed, and which of the available tools would be most effective for gathering this information. For example using the list_files tool is more effective than running a command like \`ls\` in the terminal. It's critical that you think about each available tool and use the one that best fits the current step in the task.
|
||||||
@@ -19,4 +19,4 @@ It is crucial to proceed step-by-step, waiting for the user's message after each
|
|||||||
4. Ensure that each action builds correctly on the previous ones.
|
4. Ensure that each action builds correctly on the previous ones.
|
||||||
|
|
||||||
By waiting for and carefully considering the user's response after each tool use, you can react accordingly and make informed decisions about how to proceed with the task. This iterative process helps ensure the overall success and accuracy of your work.`
|
By waiting for and carefully considering the user's response after each tool use, you can react accordingly and make informed decisions about how to proceed with the task. This iterative process helps ensure the overall success and accuracy of your work.`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
export function getSharedToolUseSection(): string {
|
export function getSharedToolUseSection(): string {
|
||||||
return `====
|
return `====
|
||||||
|
|
||||||
TOOL USE
|
TOOL USE
|
||||||
|
|
||||||
@@ -22,4 +22,4 @@ For example:
|
|||||||
</read_file>
|
</read_file>
|
||||||
|
|
||||||
Always adhere to this format for the tool use to ensure proper parsing and execution.`
|
Always adhere to this format for the tool use to ensure proper parsing and execution.`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,87 +3,84 @@ import { DiffStrategy } from "../diff/DiffStrategy"
|
|||||||
import { McpHub } from "../../services/mcp/McpHub"
|
import { McpHub } from "../../services/mcp/McpHub"
|
||||||
import { getToolDescriptionsForMode } from "./tools"
|
import { getToolDescriptionsForMode } from "./tools"
|
||||||
import {
|
import {
|
||||||
getRulesSection,
|
getRulesSection,
|
||||||
getSystemInfoSection,
|
getSystemInfoSection,
|
||||||
getObjectiveSection,
|
getObjectiveSection,
|
||||||
getSharedToolUseSection,
|
getSharedToolUseSection,
|
||||||
getMcpServersSection,
|
getMcpServersSection,
|
||||||
getToolUseGuidelinesSection,
|
getToolUseGuidelinesSection,
|
||||||
getCapabilitiesSection
|
getCapabilitiesSection,
|
||||||
} from "./sections"
|
} from "./sections"
|
||||||
import fs from 'fs/promises'
|
import fs from "fs/promises"
|
||||||
import path from 'path'
|
import path from "path"
|
||||||
|
|
||||||
async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> {
|
async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> {
|
||||||
let combinedRules = ''
|
let combinedRules = ""
|
||||||
|
|
||||||
// First try mode-specific rules
|
// First try mode-specific rules
|
||||||
const modeSpecificFile = `.clinerules-${mode}`
|
const modeSpecificFile = `.clinerules-${mode}`
|
||||||
try {
|
try {
|
||||||
const content = await fs.readFile(path.join(cwd, modeSpecificFile), 'utf-8')
|
const content = await fs.readFile(path.join(cwd, modeSpecificFile), "utf-8")
|
||||||
if (content.trim()) {
|
if (content.trim()) {
|
||||||
combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n`
|
combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n`
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
// Silently skip if file doesn't exist
|
// Silently skip if file doesn't exist
|
||||||
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') {
|
if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then try generic rules files
|
// Then try generic rules files
|
||||||
const genericRuleFiles = ['.clinerules']
|
const genericRuleFiles = [".clinerules"]
|
||||||
for (const file of genericRuleFiles) {
|
for (const file of genericRuleFiles) {
|
||||||
try {
|
try {
|
||||||
const content = await fs.readFile(path.join(cwd, file), 'utf-8')
|
const content = await fs.readFile(path.join(cwd, file), "utf-8")
|
||||||
if (content.trim()) {
|
if (content.trim()) {
|
||||||
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
|
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
// Silently skip if file doesn't exist
|
// Silently skip if file doesn't exist
|
||||||
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') {
|
if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return combinedRules
|
return combinedRules
|
||||||
}
|
}
|
||||||
|
|
||||||
interface State {
|
interface State {
|
||||||
customInstructions?: string;
|
customInstructions?: string
|
||||||
customPrompts?: CustomPrompts;
|
customPrompts?: CustomPrompts
|
||||||
preferredLanguage?: string;
|
preferredLanguage?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function addCustomInstructions(
|
export async function addCustomInstructions(state: State, cwd: string, mode: Mode = defaultModeSlug): Promise<string> {
|
||||||
state: State,
|
const ruleFileContent = await loadRuleFiles(cwd, mode)
|
||||||
cwd: string,
|
const allInstructions = []
|
||||||
mode: Mode = defaultModeSlug
|
|
||||||
): Promise<string> {
|
|
||||||
const ruleFileContent = await loadRuleFiles(cwd, mode)
|
|
||||||
const allInstructions = []
|
|
||||||
|
|
||||||
if (state.preferredLanguage) {
|
if (state.preferredLanguage) {
|
||||||
allInstructions.push(`You should always speak and think in the ${state.preferredLanguage} language.`)
|
allInstructions.push(`You should always speak and think in the ${state.preferredLanguage} language.`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.customInstructions?.trim()) {
|
if (state.customInstructions?.trim()) {
|
||||||
allInstructions.push(state.customInstructions.trim())
|
allInstructions.push(state.customInstructions.trim())
|
||||||
}
|
}
|
||||||
|
|
||||||
const customPrompt = state.customPrompts?.[mode]
|
const customPrompt = state.customPrompts?.[mode]
|
||||||
if (typeof customPrompt === 'object' && customPrompt?.customInstructions?.trim()) {
|
if (typeof customPrompt === "object" && customPrompt?.customInstructions?.trim()) {
|
||||||
allInstructions.push(customPrompt.customInstructions.trim())
|
allInstructions.push(customPrompt.customInstructions.trim())
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ruleFileContent && ruleFileContent.trim()) {
|
if (ruleFileContent && ruleFileContent.trim()) {
|
||||||
allInstructions.push(ruleFileContent.trim())
|
allInstructions.push(ruleFileContent.trim())
|
||||||
}
|
}
|
||||||
|
|
||||||
const joinedInstructions = allInstructions.join('\n\n')
|
const joinedInstructions = allInstructions.join("\n\n")
|
||||||
|
|
||||||
return joinedInstructions ? `
|
return joinedInstructions
|
||||||
|
? `
|
||||||
====
|
====
|
||||||
|
|
||||||
USER'S CUSTOM INSTRUCTIONS
|
USER'S CUSTOM INSTRUCTIONS
|
||||||
@@ -91,19 +88,19 @@ USER'S CUSTOM INSTRUCTIONS
|
|||||||
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the TOOL USE guidelines.
|
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the TOOL USE guidelines.
|
||||||
|
|
||||||
${joinedInstructions}`
|
${joinedInstructions}`
|
||||||
: ""
|
: ""
|
||||||
}
|
}
|
||||||
|
|
||||||
async function generatePrompt(
|
async function generatePrompt(
|
||||||
cwd: string,
|
cwd: string,
|
||||||
supportsComputerUse: boolean,
|
supportsComputerUse: boolean,
|
||||||
mode: Mode,
|
mode: Mode,
|
||||||
mcpHub?: McpHub,
|
mcpHub?: McpHub,
|
||||||
diffStrategy?: DiffStrategy,
|
diffStrategy?: DiffStrategy,
|
||||||
browserViewportSize?: string,
|
browserViewportSize?: string,
|
||||||
promptComponent?: PromptComponent,
|
promptComponent?: PromptComponent,
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
const basePrompt = `${promptComponent?.roleDefinition || getRoleDefinition(mode)}
|
const basePrompt = `${promptComponent?.roleDefinition || getRoleDefinition(mode)}
|
||||||
|
|
||||||
${getSharedToolUseSection()}
|
${getSharedToolUseSection()}
|
||||||
|
|
||||||
@@ -119,38 +116,38 @@ ${getRulesSection(cwd, supportsComputerUse, diffStrategy)}
|
|||||||
|
|
||||||
${getSystemInfoSection(cwd)}
|
${getSystemInfoSection(cwd)}
|
||||||
|
|
||||||
${getObjectiveSection()}`;
|
${getObjectiveSection()}`
|
||||||
|
|
||||||
return basePrompt;
|
return basePrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
export const SYSTEM_PROMPT = async (
|
export const SYSTEM_PROMPT = async (
|
||||||
cwd: string,
|
cwd: string,
|
||||||
supportsComputerUse: boolean,
|
supportsComputerUse: boolean,
|
||||||
mcpHub?: McpHub,
|
mcpHub?: McpHub,
|
||||||
diffStrategy?: DiffStrategy,
|
diffStrategy?: DiffStrategy,
|
||||||
browserViewportSize?: string,
|
browserViewportSize?: string,
|
||||||
mode: Mode = defaultModeSlug,
|
mode: Mode = defaultModeSlug,
|
||||||
customPrompts?: CustomPrompts,
|
customPrompts?: CustomPrompts,
|
||||||
) => {
|
) => {
|
||||||
const getPromptComponent = (value: unknown) => {
|
const getPromptComponent = (value: unknown) => {
|
||||||
if (typeof value === 'object' && value !== null) {
|
if (typeof value === "object" && value !== null) {
|
||||||
return value as PromptComponent;
|
return value as PromptComponent
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined
|
||||||
};
|
}
|
||||||
|
|
||||||
// Use default mode if not found
|
// Use default mode if not found
|
||||||
const currentMode = modes.find(m => m.slug === mode) || modes[0];
|
const currentMode = modes.find((m) => m.slug === mode) || modes[0]
|
||||||
const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug]);
|
const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug])
|
||||||
|
|
||||||
return generatePrompt(
|
return generatePrompt(
|
||||||
cwd,
|
cwd,
|
||||||
supportsComputerUse,
|
supportsComputerUse,
|
||||||
currentMode.slug,
|
currentMode.slug,
|
||||||
mcpHub,
|
mcpHub,
|
||||||
diffStrategy,
|
diffStrategy,
|
||||||
browserViewportSize,
|
browserViewportSize,
|
||||||
promptComponent
|
promptComponent,
|
||||||
);
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getAccessMcpResourceDescription(args: ToolArgs): string | undefined {
|
export function getAccessMcpResourceDescription(args: ToolArgs): string | undefined {
|
||||||
if (!args.mcpHub) {
|
if (!args.mcpHub) {
|
||||||
return undefined;
|
return undefined
|
||||||
}
|
}
|
||||||
return `## access_mcp_resource
|
return `## access_mcp_resource
|
||||||
Description: Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.
|
Description: Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.
|
||||||
Parameters:
|
Parameters:
|
||||||
- server_name: (required) The name of the MCP server providing the resource
|
- server_name: (required) The name of the MCP server providing the resource
|
||||||
@@ -21,4 +21,4 @@ Example: Requesting to access an MCP resource
|
|||||||
<server_name>weather-server</server_name>
|
<server_name>weather-server</server_name>
|
||||||
<uri>weather://san-francisco/current</uri>
|
<uri>weather://san-francisco/current</uri>
|
||||||
</access_mcp_resource>`
|
</access_mcp_resource>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
export function getAskFollowupQuestionDescription(): string {
|
export function getAskFollowupQuestionDescription(): string {
|
||||||
return `## ask_followup_question
|
return `## ask_followup_question
|
||||||
Description: Ask the user a question to gather additional information needed to complete the task. This tool should be used when you encounter ambiguities, need clarification, or require more details to proceed effectively. It allows for interactive problem-solving by enabling direct communication with the user. Use this tool judiciously to maintain a balance between gathering necessary information and avoiding excessive back-and-forth.
|
Description: Ask the user a question to gather additional information needed to complete the task. This tool should be used when you encounter ambiguities, need clarification, or require more details to proceed effectively. It allows for interactive problem-solving by enabling direct communication with the user. Use this tool judiciously to maintain a balance between gathering necessary information and avoiding excessive back-and-forth.
|
||||||
Parameters:
|
Parameters:
|
||||||
- question: (required) The question to ask the user. This should be a clear, specific question that addresses the information you need.
|
- question: (required) The question to ask the user. This should be a clear, specific question that addresses the information you need.
|
||||||
@@ -12,4 +12,4 @@ Example: Requesting to ask the user for the path to the frontend-config.json fil
|
|||||||
<ask_followup_question>
|
<ask_followup_question>
|
||||||
<question>What is the path to the frontend-config.json file?</question>
|
<question>What is the path to the frontend-config.json file?</question>
|
||||||
</ask_followup_question>`
|
</ask_followup_question>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
export function getAttemptCompletionDescription(): string {
|
export function getAttemptCompletionDescription(): string {
|
||||||
return `## attempt_completion
|
return `## attempt_completion
|
||||||
Description: After each tool use, the user will respond with the result of that tool use, i.e. if it succeeded or failed, along with any reasons for failure. Once you've received the results of tool uses and can confirm that the task is complete, use this tool to present the result of your work to the user. Optionally you may provide a CLI command to showcase the result of your work. The user may respond with feedback if they are not satisfied with the result, which you can use to make improvements and try again.
|
Description: After each tool use, the user will respond with the result of that tool use, i.e. if it succeeded or failed, along with any reasons for failure. Once you've received the results of tool uses and can confirm that the task is complete, use this tool to present the result of your work to the user. Optionally you may provide a CLI command to showcase the result of your work. The user may respond with feedback if they are not satisfied with the result, which you can use to make improvements and try again.
|
||||||
IMPORTANT NOTE: This tool CANNOT be used until you've confirmed from the user that any previous tool uses were successful. Failure to do so will result in code corruption and system failure. Before using this tool, you must ask yourself in <thinking></thinking> tags if you've confirmed from the user that any previous tool uses were successful. If not, then DO NOT use this tool.
|
IMPORTANT NOTE: This tool CANNOT be used until you've confirmed from the user that any previous tool uses were successful. Failure to do so will result in code corruption and system failure. Before using this tool, you must ask yourself in <thinking></thinking> tags if you've confirmed from the user that any previous tool uses were successful. If not, then DO NOT use this tool.
|
||||||
Parameters:
|
Parameters:
|
||||||
@@ -20,4 +20,4 @@ I've updated the CSS
|
|||||||
</result>
|
</result>
|
||||||
<command>open index.html</command>
|
<command>open index.html</command>
|
||||||
</attempt_completion>`
|
</attempt_completion>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getBrowserActionDescription(args: ToolArgs): string | undefined {
|
export function getBrowserActionDescription(args: ToolArgs): string | undefined {
|
||||||
if (!args.supportsComputerUse) {
|
if (!args.supportsComputerUse) {
|
||||||
return undefined;
|
return undefined
|
||||||
}
|
}
|
||||||
return `## browser_action
|
return `## browser_action
|
||||||
Description: Request to interact with a Puppeteer-controlled browser. Every action, except \`close\`, will be responded to with a screenshot of the browser's current state, along with any new console logs. You may only perform one browser action per message, and wait for the user's response including a screenshot and logs to determine the next action.
|
Description: Request to interact with a Puppeteer-controlled browser. Every action, except \`close\`, will be responded to with a screenshot of the browser's current state, along with any new console logs. You may only perform one browser action per message, and wait for the user's response including a screenshot and logs to determine the next action.
|
||||||
- The sequence of actions **must always start with** launching the browser at a URL, and **must always end with** closing the browser. If you need to visit a new URL that is not possible to navigate to from the current webpage, you must first close the browser, then launch again at the new URL.
|
- The sequence of actions **must always start with** launching the browser at a URL, and **must always end with** closing the browser. If you need to visit a new URL that is not possible to navigate to from the current webpage, you must first close the browser, then launch again at the new URL.
|
||||||
- While the browser is active, only the \`browser_action\` tool can be used. No other tools should be called during this time. You may proceed to use other tools only after closing the browser. For example if you run into an error and need to fix a file, you must close the browser, then use other tools to make the necessary changes, then re-launch the browser to verify the result.
|
- While the browser is active, only the \`browser_action\` tool can be used. No other tools should be called during this time. You may proceed to use other tools only after closing the browser. For example if you run into an error and need to fix a file, you must close the browser, then use other tools to make the necessary changes, then re-launch the browser to verify the result.
|
||||||
@@ -49,4 +49,4 @@ Example: Requesting to click on the element at coordinates 450,300
|
|||||||
<action>click</action>
|
<action>click</action>
|
||||||
<coordinate>450,300</coordinate>
|
<coordinate>450,300</coordinate>
|
||||||
</browser_action>`
|
</browser_action>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getExecuteCommandDescription(args: ToolArgs): string | undefined {
|
export function getExecuteCommandDescription(args: ToolArgs): string | undefined {
|
||||||
return `## execute_command
|
return `## execute_command
|
||||||
Description: Request to execute a CLI command on the system. Use this when you need to perform system operations or run specific commands to accomplish any step in the user's task. You must tailor your command to the user's system and provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, as they are more flexible and easier to run. Commands will be executed in the current working directory: ${args.cwd}
|
Description: Request to execute a CLI command on the system. Use this when you need to perform system operations or run specific commands to accomplish any step in the user's task. You must tailor your command to the user's system and provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, as they are more flexible and easier to run. Commands will be executed in the current working directory: ${args.cwd}
|
||||||
Parameters:
|
Parameters:
|
||||||
- command: (required) The CLI command to execute. This should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions.
|
- command: (required) The CLI command to execute. This should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions.
|
||||||
@@ -14,4 +14,4 @@ Example: Requesting to execute npm run dev
|
|||||||
<execute_command>
|
<execute_command>
|
||||||
<command>npm run dev</command>
|
<command>npm run dev</command>
|
||||||
</execute_command>`
|
</execute_command>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,79 +1,80 @@
|
|||||||
import { getExecuteCommandDescription } from './execute-command'
|
import { getExecuteCommandDescription } from "./execute-command"
|
||||||
import { getReadFileDescription } from './read-file'
|
import { getReadFileDescription } from "./read-file"
|
||||||
import { getWriteToFileDescription } from './write-to-file'
|
import { getWriteToFileDescription } from "./write-to-file"
|
||||||
import { getSearchFilesDescription } from './search-files'
|
import { getSearchFilesDescription } from "./search-files"
|
||||||
import { getListFilesDescription } from './list-files'
|
import { getListFilesDescription } from "./list-files"
|
||||||
import { getListCodeDefinitionNamesDescription } from './list-code-definition-names'
|
import { getListCodeDefinitionNamesDescription } from "./list-code-definition-names"
|
||||||
import { getBrowserActionDescription } from './browser-action'
|
import { getBrowserActionDescription } from "./browser-action"
|
||||||
import { getAskFollowupQuestionDescription } from './ask-followup-question'
|
import { getAskFollowupQuestionDescription } from "./ask-followup-question"
|
||||||
import { getAttemptCompletionDescription } from './attempt-completion'
|
import { getAttemptCompletionDescription } from "./attempt-completion"
|
||||||
import { getUseMcpToolDescription } from './use-mcp-tool'
|
import { getUseMcpToolDescription } from "./use-mcp-tool"
|
||||||
import { getAccessMcpResourceDescription } from './access-mcp-resource'
|
import { getAccessMcpResourceDescription } from "./access-mcp-resource"
|
||||||
import { DiffStrategy } from '../../diff/DiffStrategy'
|
import { DiffStrategy } from "../../diff/DiffStrategy"
|
||||||
import { McpHub } from '../../../services/mcp/McpHub'
|
import { McpHub } from "../../../services/mcp/McpHub"
|
||||||
import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from '../../../shared/modes'
|
import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from "../../../shared/modes"
|
||||||
import { ToolArgs } from './types'
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
// Map of tool names to their description functions
|
// Map of tool names to their description functions
|
||||||
const toolDescriptionMap: Record<string, (args: ToolArgs) => string | undefined> = {
|
const toolDescriptionMap: Record<string, (args: ToolArgs) => string | undefined> = {
|
||||||
'execute_command': args => getExecuteCommandDescription(args),
|
execute_command: (args) => getExecuteCommandDescription(args),
|
||||||
'read_file': args => getReadFileDescription(args),
|
read_file: (args) => getReadFileDescription(args),
|
||||||
'write_to_file': args => getWriteToFileDescription(args),
|
write_to_file: (args) => getWriteToFileDescription(args),
|
||||||
'search_files': args => getSearchFilesDescription(args),
|
search_files: (args) => getSearchFilesDescription(args),
|
||||||
'list_files': args => getListFilesDescription(args),
|
list_files: (args) => getListFilesDescription(args),
|
||||||
'list_code_definition_names': args => getListCodeDefinitionNamesDescription(args),
|
list_code_definition_names: (args) => getListCodeDefinitionNamesDescription(args),
|
||||||
'browser_action': args => getBrowserActionDescription(args),
|
browser_action: (args) => getBrowserActionDescription(args),
|
||||||
'ask_followup_question': () => getAskFollowupQuestionDescription(),
|
ask_followup_question: () => getAskFollowupQuestionDescription(),
|
||||||
'attempt_completion': () => getAttemptCompletionDescription(),
|
attempt_completion: () => getAttemptCompletionDescription(),
|
||||||
'use_mcp_tool': args => getUseMcpToolDescription(args),
|
use_mcp_tool: (args) => getUseMcpToolDescription(args),
|
||||||
'access_mcp_resource': args => getAccessMcpResourceDescription(args),
|
access_mcp_resource: (args) => getAccessMcpResourceDescription(args),
|
||||||
'apply_diff': args => args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : ''
|
apply_diff: (args) =>
|
||||||
};
|
args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : "",
|
||||||
|
}
|
||||||
|
|
||||||
export function getToolDescriptionsForMode(
|
export function getToolDescriptionsForMode(
|
||||||
mode: Mode,
|
mode: Mode,
|
||||||
cwd: string,
|
cwd: string,
|
||||||
supportsComputerUse: boolean,
|
supportsComputerUse: boolean,
|
||||||
diffStrategy?: DiffStrategy,
|
diffStrategy?: DiffStrategy,
|
||||||
browserViewportSize?: string,
|
browserViewportSize?: string,
|
||||||
mcpHub?: McpHub
|
mcpHub?: McpHub,
|
||||||
): string {
|
): string {
|
||||||
const config = getModeConfig(mode);
|
const config = getModeConfig(mode)
|
||||||
const args: ToolArgs = {
|
const args: ToolArgs = {
|
||||||
cwd,
|
cwd,
|
||||||
supportsComputerUse,
|
supportsComputerUse,
|
||||||
diffStrategy,
|
diffStrategy,
|
||||||
browserViewportSize,
|
browserViewportSize,
|
||||||
mcpHub
|
mcpHub,
|
||||||
};
|
}
|
||||||
|
|
||||||
// Map tool descriptions in the exact order specified in the mode's tools array
|
// Map tool descriptions in the exact order specified in the mode's tools array
|
||||||
const descriptions = config.tools.map(([toolName, toolOptions]) => {
|
const descriptions = config.tools.map(([toolName, toolOptions]) => {
|
||||||
const descriptionFn = toolDescriptionMap[toolName];
|
const descriptionFn = toolDescriptionMap[toolName]
|
||||||
if (!descriptionFn || !isToolAllowedForMode(toolName as ToolName, mode)) {
|
if (!descriptionFn || !isToolAllowedForMode(toolName as ToolName, mode)) {
|
||||||
return undefined;
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
return descriptionFn({
|
return descriptionFn({
|
||||||
...args,
|
...args,
|
||||||
toolOptions
|
toolOptions,
|
||||||
});
|
})
|
||||||
});
|
})
|
||||||
|
|
||||||
return `# Tools\n\n${descriptions.filter(Boolean).join('\n\n')}`;
|
return `# Tools\n\n${descriptions.filter(Boolean).join("\n\n")}`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Export individual description functions for backward compatibility
|
// Export individual description functions for backward compatibility
|
||||||
export {
|
export {
|
||||||
getExecuteCommandDescription,
|
getExecuteCommandDescription,
|
||||||
getReadFileDescription,
|
getReadFileDescription,
|
||||||
getWriteToFileDescription,
|
getWriteToFileDescription,
|
||||||
getSearchFilesDescription,
|
getSearchFilesDescription,
|
||||||
getListFilesDescription,
|
getListFilesDescription,
|
||||||
getListCodeDefinitionNamesDescription,
|
getListCodeDefinitionNamesDescription,
|
||||||
getBrowserActionDescription,
|
getBrowserActionDescription,
|
||||||
getAskFollowupQuestionDescription,
|
getAskFollowupQuestionDescription,
|
||||||
getAttemptCompletionDescription,
|
getAttemptCompletionDescription,
|
||||||
getUseMcpToolDescription,
|
getUseMcpToolDescription,
|
||||||
getAccessMcpResourceDescription
|
getAccessMcpResourceDescription,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getListCodeDefinitionNamesDescription(args: ToolArgs): string {
|
export function getListCodeDefinitionNamesDescription(args: ToolArgs): string {
|
||||||
return `## list_code_definition_names
|
return `## list_code_definition_names
|
||||||
Description: Request to list definition names (classes, functions, methods, etc.) used in source code files at the top level of the specified directory. This tool provides insights into the codebase structure and important constructs, encapsulating high-level concepts and relationships that are crucial for understanding the overall architecture.
|
Description: Request to list definition names (classes, functions, methods, etc.) used in source code files at the top level of the specified directory. This tool provides insights into the codebase structure and important constructs, encapsulating high-level concepts and relationships that are crucial for understanding the overall architecture.
|
||||||
Parameters:
|
Parameters:
|
||||||
- path: (required) The path of the directory (relative to the current working directory ${args.cwd}) to list top level source code definitions for.
|
- path: (required) The path of the directory (relative to the current working directory ${args.cwd}) to list top level source code definitions for.
|
||||||
@@ -14,4 +14,4 @@ Example: Requesting to list all top level source code definitions in the current
|
|||||||
<list_code_definition_names>
|
<list_code_definition_names>
|
||||||
<path>.</path>
|
<path>.</path>
|
||||||
</list_code_definition_names>`
|
</list_code_definition_names>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getListFilesDescription(args: ToolArgs): string {
|
export function getListFilesDescription(args: ToolArgs): string {
|
||||||
return `## list_files
|
return `## list_files
|
||||||
Description: Request to list files and directories within the specified directory. If recursive is true, it will list all files and directories recursively. If recursive is false or not provided, it will only list the top-level contents. Do not use this tool to confirm the existence of files you may have created, as the user will let you know if the files were created successfully or not.
|
Description: Request to list files and directories within the specified directory. If recursive is true, it will list all files and directories recursively. If recursive is false or not provided, it will only list the top-level contents. Do not use this tool to confirm the existence of files you may have created, as the user will let you know if the files were created successfully or not.
|
||||||
Parameters:
|
Parameters:
|
||||||
- path: (required) The path of the directory to list contents for (relative to the current working directory ${args.cwd})
|
- path: (required) The path of the directory to list contents for (relative to the current working directory ${args.cwd})
|
||||||
@@ -17,4 +17,4 @@ Example: Requesting to list all files in the current directory
|
|||||||
<path>.</path>
|
<path>.</path>
|
||||||
<recursive>false</recursive>
|
<recursive>false</recursive>
|
||||||
</list_files>`
|
</list_files>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getReadFileDescription(args: ToolArgs): string {
|
export function getReadFileDescription(args: ToolArgs): string {
|
||||||
return `## read_file
|
return `## read_file
|
||||||
Description: Request to read the contents of a file at the specified path. Use this when you need to examine the contents of an existing file you do not know the contents of, for example to analyze code, review text files, or extract information from configuration files. The output includes line numbers prefixed to each line (e.g. "1 | const x = 1"), making it easier to reference specific lines when creating diffs or discussing code. Automatically extracts raw text from PDF and DOCX files. May not be suitable for other types of binary files, as it returns the raw content as a string.
|
Description: Request to read the contents of a file at the specified path. Use this when you need to examine the contents of an existing file you do not know the contents of, for example to analyze code, review text files, or extract information from configuration files. The output includes line numbers prefixed to each line (e.g. "1 | const x = 1"), making it easier to reference specific lines when creating diffs or discussing code. Automatically extracts raw text from PDF and DOCX files. May not be suitable for other types of binary files, as it returns the raw content as a string.
|
||||||
Parameters:
|
Parameters:
|
||||||
- path: (required) The path of the file to read (relative to the current working directory ${args.cwd})
|
- path: (required) The path of the file to read (relative to the current working directory ${args.cwd})
|
||||||
@@ -14,4 +14,4 @@ Example: Requesting to read frontend-config.json
|
|||||||
<read_file>
|
<read_file>
|
||||||
<path>frontend-config.json</path>
|
<path>frontend-config.json</path>
|
||||||
</read_file>`
|
</read_file>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getSearchFilesDescription(args: ToolArgs): string {
|
export function getSearchFilesDescription(args: ToolArgs): string {
|
||||||
return `## search_files
|
return `## search_files
|
||||||
Description: Request to perform a regex search across files in a specified directory, providing context-rich results. This tool searches for patterns or specific content across multiple files, displaying each match with encapsulating context.
|
Description: Request to perform a regex search across files in a specified directory, providing context-rich results. This tool searches for patterns or specific content across multiple files, displaying each match with encapsulating context.
|
||||||
Parameters:
|
Parameters:
|
||||||
- path: (required) The path of the directory to search in (relative to the current working directory ${args.cwd}). This directory will be recursively searched.
|
- path: (required) The path of the directory to search in (relative to the current working directory ${args.cwd}). This directory will be recursively searched.
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import { DiffStrategy } from '../../diff/DiffStrategy'
|
import { DiffStrategy } from "../../diff/DiffStrategy"
|
||||||
import { McpHub } from '../../../services/mcp/McpHub'
|
import { McpHub } from "../../../services/mcp/McpHub"
|
||||||
|
|
||||||
export type ToolArgs = {
|
export type ToolArgs = {
|
||||||
cwd: string;
|
cwd: string
|
||||||
supportsComputerUse: boolean;
|
supportsComputerUse: boolean
|
||||||
diffStrategy?: DiffStrategy;
|
diffStrategy?: DiffStrategy
|
||||||
browserViewportSize?: string;
|
browserViewportSize?: string
|
||||||
mcpHub?: McpHub;
|
mcpHub?: McpHub
|
||||||
toolOptions?: any;
|
toolOptions?: any
|
||||||
};
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getUseMcpToolDescription(args: ToolArgs): string | undefined {
|
export function getUseMcpToolDescription(args: ToolArgs): string | undefined {
|
||||||
if (!args.mcpHub) {
|
if (!args.mcpHub) {
|
||||||
return undefined;
|
return undefined
|
||||||
}
|
}
|
||||||
return `## use_mcp_tool
|
return `## use_mcp_tool
|
||||||
Description: Request to use a tool provided by a connected MCP server. Each MCP server can provide multiple tools with different capabilities. Tools have defined input schemas that specify required and optional parameters.
|
Description: Request to use a tool provided by a connected MCP server. Each MCP server can provide multiple tools with different capabilities. Tools have defined input schemas that specify required and optional parameters.
|
||||||
Parameters:
|
Parameters:
|
||||||
- server_name: (required) The name of the MCP server providing the tool
|
- server_name: (required) The name of the MCP server providing the tool
|
||||||
@@ -34,4 +34,4 @@ Example: Requesting to use an MCP tool
|
|||||||
}
|
}
|
||||||
</arguments>
|
</arguments>
|
||||||
</use_mcp_tool>`
|
</use_mcp_tool>`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { ToolArgs } from './types';
|
import { ToolArgs } from "./types"
|
||||||
|
|
||||||
export function getWriteToFileDescription(args: ToolArgs): string {
|
export function getWriteToFileDescription(args: ToolArgs): string {
|
||||||
return `## write_to_file
|
return `## write_to_file
|
||||||
Description: Request to write full content to a file at the specified path. If the file exists, it will be overwritten with the provided content. If the file doesn't exist, it will be created. This tool will automatically create any directories needed to write the file.
|
Description: Request to write full content to a file at the specified path. If the file exists, it will be overwritten with the provided content. If the file doesn't exist, it will be created. This tool will automatically create any directories needed to write the file.
|
||||||
Parameters:
|
Parameters:
|
||||||
- path: (required) The path of the file to write to (relative to the current working directory ${args.cwd})
|
- path: (required) The path of the file to write to (relative to the current working directory ${args.cwd})
|
||||||
|
|||||||
@@ -1,52 +1,52 @@
|
|||||||
import { Mode } from '../../shared/modes';
|
import { Mode } from "../../shared/modes"
|
||||||
|
|
||||||
export type { Mode };
|
export type { Mode }
|
||||||
|
|
||||||
export type ToolName =
|
export type ToolName =
|
||||||
| 'execute_command'
|
| "execute_command"
|
||||||
| 'read_file'
|
| "read_file"
|
||||||
| 'write_to_file'
|
| "write_to_file"
|
||||||
| 'apply_diff'
|
| "apply_diff"
|
||||||
| 'search_files'
|
| "search_files"
|
||||||
| 'list_files'
|
| "list_files"
|
||||||
| 'list_code_definition_names'
|
| "list_code_definition_names"
|
||||||
| 'browser_action'
|
| "browser_action"
|
||||||
| 'use_mcp_tool'
|
| "use_mcp_tool"
|
||||||
| 'access_mcp_resource'
|
| "access_mcp_resource"
|
||||||
| 'ask_followup_question'
|
| "ask_followup_question"
|
||||||
| 'attempt_completion';
|
| "attempt_completion"
|
||||||
|
|
||||||
export const CODE_TOOLS: ToolName[] = [
|
export const CODE_TOOLS: ToolName[] = [
|
||||||
'execute_command',
|
"execute_command",
|
||||||
'read_file',
|
"read_file",
|
||||||
'write_to_file',
|
"write_to_file",
|
||||||
'apply_diff',
|
"apply_diff",
|
||||||
'search_files',
|
"search_files",
|
||||||
'list_files',
|
"list_files",
|
||||||
'list_code_definition_names',
|
"list_code_definition_names",
|
||||||
'browser_action',
|
"browser_action",
|
||||||
'use_mcp_tool',
|
"use_mcp_tool",
|
||||||
'access_mcp_resource',
|
"access_mcp_resource",
|
||||||
'ask_followup_question',
|
"ask_followup_question",
|
||||||
'attempt_completion'
|
"attempt_completion",
|
||||||
];
|
]
|
||||||
|
|
||||||
export const ARCHITECT_TOOLS: ToolName[] = [
|
export const ARCHITECT_TOOLS: ToolName[] = [
|
||||||
'read_file',
|
"read_file",
|
||||||
'search_files',
|
"search_files",
|
||||||
'list_files',
|
"list_files",
|
||||||
'list_code_definition_names',
|
"list_code_definition_names",
|
||||||
'ask_followup_question',
|
"ask_followup_question",
|
||||||
'attempt_completion'
|
"attempt_completion",
|
||||||
];
|
]
|
||||||
|
|
||||||
export const ASK_TOOLS: ToolName[] = [
|
export const ASK_TOOLS: ToolName[] = [
|
||||||
'read_file',
|
"read_file",
|
||||||
'search_files',
|
"search_files",
|
||||||
'list_files',
|
"list_files",
|
||||||
'browser_action',
|
"browser_action",
|
||||||
'use_mcp_tool',
|
"use_mcp_tool",
|
||||||
'access_mcp_resource',
|
"access_mcp_resource",
|
||||||
'ask_followup_question',
|
"ask_followup_question",
|
||||||
'attempt_completion'
|
"attempt_completion",
|
||||||
];
|
]
|
||||||
|
|||||||
@@ -1,32 +1,32 @@
|
|||||||
// Shared tools for architect and ask modes - read-only operations plus MCP and browser tools
|
// Shared tools for architect and ask modes - read-only operations plus MCP and browser tools
|
||||||
export const READONLY_ALLOWED_TOOLS = [
|
export const READONLY_ALLOWED_TOOLS = [
|
||||||
'read_file',
|
"read_file",
|
||||||
'search_files',
|
"search_files",
|
||||||
'list_files',
|
"list_files",
|
||||||
'list_code_definition_names',
|
"list_code_definition_names",
|
||||||
'browser_action',
|
"browser_action",
|
||||||
'use_mcp_tool',
|
"use_mcp_tool",
|
||||||
'access_mcp_resource',
|
"access_mcp_resource",
|
||||||
'ask_followup_question',
|
"ask_followup_question",
|
||||||
'attempt_completion'
|
"attempt_completion",
|
||||||
] as const;
|
] as const
|
||||||
|
|
||||||
// Code mode has access to all tools
|
// Code mode has access to all tools
|
||||||
export const CODE_ALLOWED_TOOLS = [
|
export const CODE_ALLOWED_TOOLS = [
|
||||||
'execute_command',
|
"execute_command",
|
||||||
'read_file',
|
"read_file",
|
||||||
'write_to_file',
|
"write_to_file",
|
||||||
'apply_diff',
|
"apply_diff",
|
||||||
'search_files',
|
"search_files",
|
||||||
'list_files',
|
"list_files",
|
||||||
'list_code_definition_names',
|
"list_code_definition_names",
|
||||||
'browser_action',
|
"browser_action",
|
||||||
'use_mcp_tool',
|
"use_mcp_tool",
|
||||||
'access_mcp_resource',
|
"access_mcp_resource",
|
||||||
'ask_followup_question',
|
"ask_followup_question",
|
||||||
'attempt_completion'
|
"attempt_completion",
|
||||||
] as const;
|
] as const
|
||||||
|
|
||||||
// Tool name types for type safety
|
// Tool name types for type safety
|
||||||
export type ReadOnlyToolName = typeof READONLY_ALLOWED_TOOLS[number];
|
export type ReadOnlyToolName = (typeof READONLY_ALLOWED_TOOLS)[number]
|
||||||
export type ToolName = typeof CODE_ALLOWED_TOOLS[number];
|
export type ToolName = (typeof CODE_ALLOWED_TOOLS)[number]
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ type GlobalStateKey =
|
|||||||
| "modeApiConfigs"
|
| "modeApiConfigs"
|
||||||
| "customPrompts"
|
| "customPrompts"
|
||||||
| "enhancementApiConfigId"
|
| "enhancementApiConfigId"
|
||||||
| "experimentalDiffStrategy"
|
| "experimentalDiffStrategy"
|
||||||
| "autoApprovalEnabled"
|
| "autoApprovalEnabled"
|
||||||
|
|
||||||
export const GlobalFileNames = {
|
export const GlobalFileNames = {
|
||||||
@@ -254,14 +254,12 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
fuzzyMatchThreshold,
|
fuzzyMatchThreshold,
|
||||||
mode,
|
mode,
|
||||||
customInstructions: globalInstructions,
|
customInstructions: globalInstructions,
|
||||||
experimentalDiffStrategy
|
experimentalDiffStrategy,
|
||||||
} = await this.getState()
|
} = await this.getState()
|
||||||
|
|
||||||
const modePrompt = customPrompts?.[mode]
|
const modePrompt = customPrompts?.[mode]
|
||||||
const modeInstructions = typeof modePrompt === 'object' ? modePrompt.customInstructions : undefined
|
const modeInstructions = typeof modePrompt === "object" ? modePrompt.customInstructions : undefined
|
||||||
const effectiveInstructions = [globalInstructions, modeInstructions]
|
const effectiveInstructions = [globalInstructions, modeInstructions].filter(Boolean).join("\n\n")
|
||||||
.filter(Boolean)
|
|
||||||
.join('\n\n')
|
|
||||||
|
|
||||||
this.cline = new Cline(
|
this.cline = new Cline(
|
||||||
this,
|
this,
|
||||||
@@ -272,7 +270,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
task,
|
task,
|
||||||
images,
|
images,
|
||||||
undefined,
|
undefined,
|
||||||
experimentalDiffStrategy
|
experimentalDiffStrategy,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,14 +283,12 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
fuzzyMatchThreshold,
|
fuzzyMatchThreshold,
|
||||||
mode,
|
mode,
|
||||||
customInstructions: globalInstructions,
|
customInstructions: globalInstructions,
|
||||||
experimentalDiffStrategy
|
experimentalDiffStrategy,
|
||||||
} = await this.getState()
|
} = await this.getState()
|
||||||
|
|
||||||
const modePrompt = customPrompts?.[mode]
|
const modePrompt = customPrompts?.[mode]
|
||||||
const modeInstructions = typeof modePrompt === 'object' ? modePrompt.customInstructions : undefined
|
const modeInstructions = typeof modePrompt === "object" ? modePrompt.customInstructions : undefined
|
||||||
const effectiveInstructions = [globalInstructions, modeInstructions]
|
const effectiveInstructions = [globalInstructions, modeInstructions].filter(Boolean).join("\n\n")
|
||||||
.filter(Boolean)
|
|
||||||
.join('\n\n')
|
|
||||||
|
|
||||||
this.cline = new Cline(
|
this.cline = new Cline(
|
||||||
this,
|
this,
|
||||||
@@ -303,7 +299,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
undefined,
|
undefined,
|
||||||
undefined,
|
undefined,
|
||||||
historyItem,
|
historyItem,
|
||||||
experimentalDiffStrategy
|
experimentalDiffStrategy,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -403,7 +399,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
async (message: WebviewMessage) => {
|
async (message: WebviewMessage) => {
|
||||||
switch (message.type) {
|
switch (message.type) {
|
||||||
case "webviewDidLaunch":
|
case "webviewDidLaunch":
|
||||||
|
|
||||||
this.postStateToWebview()
|
this.postStateToWebview()
|
||||||
this.workspaceTracker?.initializeFilePaths() // don't await
|
this.workspaceTracker?.initializeFilePaths() // don't await
|
||||||
getTheme().then((theme) =>
|
getTheme().then((theme) =>
|
||||||
@@ -450,53 +445,53 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
this.configManager
|
||||||
this.configManager.ListConfig().then(async (listApiConfig) => {
|
.ListConfig()
|
||||||
|
.then(async (listApiConfig) => {
|
||||||
if (!listApiConfig) {
|
if (!listApiConfig) {
|
||||||
return
|
return
|
||||||
}
|
|
||||||
|
|
||||||
if (listApiConfig.length === 1) {
|
|
||||||
// check if first time init then sync with exist config
|
|
||||||
if (!checkExistKey(listApiConfig[0])) {
|
|
||||||
const {
|
|
||||||
apiConfiguration,
|
|
||||||
} = await this.getState()
|
|
||||||
await this.configManager.SaveConfig(listApiConfig[0].name ?? "default", apiConfiguration)
|
|
||||||
listApiConfig[0].apiProvider = apiConfiguration.apiProvider
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let currentConfigName = await this.getGlobalState("currentApiConfigName") as string
|
if (listApiConfig.length === 1) {
|
||||||
|
// check if first time init then sync with exist config
|
||||||
if (currentConfigName) {
|
if (!checkExistKey(listApiConfig[0])) {
|
||||||
if (!await this.configManager.HasConfig(currentConfigName)) {
|
const { apiConfiguration } = await this.getState()
|
||||||
// current config name not valid, get first config in list
|
await this.configManager.SaveConfig(
|
||||||
await this.updateGlobalState("currentApiConfigName", listApiConfig?.[0]?.name)
|
listApiConfig[0].name ?? "default",
|
||||||
if (listApiConfig?.[0]?.name) {
|
apiConfiguration,
|
||||||
const apiConfig = await this.configManager.LoadConfig(listApiConfig?.[0]?.name);
|
)
|
||||||
|
listApiConfig[0].apiProvider = apiConfiguration.apiProvider
|
||||||
await Promise.all([
|
|
||||||
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
|
||||||
this.postMessageToWebview({ type: "listApiConfig", listApiConfig }),
|
|
||||||
this.updateApiConfiguration(apiConfig),
|
|
||||||
])
|
|
||||||
await this.postStateToWebview()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
let currentConfigName = (await this.getGlobalState("currentApiConfigName")) as string
|
||||||
|
|
||||||
await Promise.all(
|
if (currentConfigName) {
|
||||||
[
|
if (!(await this.configManager.HasConfig(currentConfigName))) {
|
||||||
|
// current config name not valid, get first config in list
|
||||||
|
await this.updateGlobalState("currentApiConfigName", listApiConfig?.[0]?.name)
|
||||||
|
if (listApiConfig?.[0]?.name) {
|
||||||
|
const apiConfig = await this.configManager.LoadConfig(
|
||||||
|
listApiConfig?.[0]?.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
await Promise.all([
|
||||||
|
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||||
|
this.postMessageToWebview({ type: "listApiConfig", listApiConfig }),
|
||||||
|
this.updateApiConfiguration(apiConfig),
|
||||||
|
])
|
||||||
|
await this.postStateToWebview()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await Promise.all([
|
||||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
await this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||||
await this.postMessageToWebview({ type: "listApiConfig", listApiConfig })
|
await this.postMessageToWebview({ type: "listApiConfig", listApiConfig }),
|
||||||
]
|
])
|
||||||
)
|
})
|
||||||
}).catch(console.error);
|
.catch(console.error)
|
||||||
|
|
||||||
break
|
break
|
||||||
case "newTask":
|
case "newTask":
|
||||||
@@ -593,7 +588,10 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
break
|
break
|
||||||
case "refreshOpenAiModels":
|
case "refreshOpenAiModels":
|
||||||
if (message?.values?.baseUrl && message?.values?.apiKey) {
|
if (message?.values?.baseUrl && message?.values?.apiKey) {
|
||||||
const openAiModels = await this.getOpenAiModels(message?.values?.baseUrl, message?.values?.apiKey)
|
const openAiModels = await this.getOpenAiModels(
|
||||||
|
message?.values?.baseUrl,
|
||||||
|
message?.values?.apiKey,
|
||||||
|
)
|
||||||
this.postMessageToWebview({ type: "openAiModels", openAiModels })
|
this.postMessageToWebview({ type: "openAiModels", openAiModels })
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -625,12 +623,12 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
|
|
||||||
break
|
break
|
||||||
case "allowedCommands":
|
case "allowedCommands":
|
||||||
await this.context.globalState.update('allowedCommands', message.commands);
|
await this.context.globalState.update("allowedCommands", message.commands)
|
||||||
// Also update workspace settings
|
// Also update workspace settings
|
||||||
await vscode.workspace
|
await vscode.workspace
|
||||||
.getConfiguration('roo-cline')
|
.getConfiguration("roo-cline")
|
||||||
.update('allowedCommands', message.commands, vscode.ConfigurationTarget.Global);
|
.update("allowedCommands", message.commands, vscode.ConfigurationTarget.Global)
|
||||||
break;
|
break
|
||||||
case "openMcpSettings": {
|
case "openMcpSettings": {
|
||||||
const mcpSettingsFilePath = await this.mcpHub?.getMcpSettingsFilePath()
|
const mcpSettingsFilePath = await this.mcpHub?.getMcpSettingsFilePath()
|
||||||
if (mcpSettingsFilePath) {
|
if (mcpSettingsFilePath) {
|
||||||
@@ -651,7 +649,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
await this.mcpHub?.toggleToolAlwaysAllow(
|
await this.mcpHub?.toggleToolAlwaysAllow(
|
||||||
message.serverName!,
|
message.serverName!,
|
||||||
message.toolName!,
|
message.toolName!,
|
||||||
message.alwaysAllow!
|
message.alwaysAllow!,
|
||||||
)
|
)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Failed to toggle auto-approve for tool ${message.toolName}:`, error)
|
console.error(`Failed to toggle auto-approve for tool ${message.toolName}:`, error)
|
||||||
@@ -660,10 +658,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
case "toggleMcpServer": {
|
case "toggleMcpServer": {
|
||||||
try {
|
try {
|
||||||
await this.mcpHub?.toggleServerDisabled(
|
await this.mcpHub?.toggleServerDisabled(message.serverName!, message.disabled!)
|
||||||
message.serverName!,
|
|
||||||
message.disabled!
|
|
||||||
)
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Failed to toggle MCP server ${message.serverName}:`, error)
|
console.error(`Failed to toggle MCP server ${message.serverName}:`, error)
|
||||||
}
|
}
|
||||||
@@ -683,7 +678,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
case "soundEnabled":
|
case "soundEnabled":
|
||||||
const soundEnabled = message.bool ?? true
|
const soundEnabled = message.bool ?? true
|
||||||
await this.updateGlobalState("soundEnabled", soundEnabled)
|
await this.updateGlobalState("soundEnabled", soundEnabled)
|
||||||
setSoundEnabled(soundEnabled) // Add this line to update the sound utility
|
setSoundEnabled(soundEnabled) // Add this line to update the sound utility
|
||||||
await this.postStateToWebview()
|
await this.postStateToWebview()
|
||||||
break
|
break
|
||||||
case "soundVolume":
|
case "soundVolume":
|
||||||
@@ -729,84 +724,84 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
case "mode":
|
case "mode":
|
||||||
const newMode = message.text as Mode
|
const newMode = message.text as Mode
|
||||||
await this.updateGlobalState("mode", newMode)
|
await this.updateGlobalState("mode", newMode)
|
||||||
|
|
||||||
// Load the saved API config for the new mode if it exists
|
// Load the saved API config for the new mode if it exists
|
||||||
const savedConfigId = await this.configManager.GetModeConfigId(newMode)
|
const savedConfigId = await this.configManager.GetModeConfigId(newMode)
|
||||||
const listApiConfig = await this.configManager.ListConfig()
|
const listApiConfig = await this.configManager.ListConfig()
|
||||||
|
|
||||||
// Update listApiConfigMeta first to ensure UI has latest data
|
// Update listApiConfigMeta first to ensure UI has latest data
|
||||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||||
|
|
||||||
// If this mode has a saved config, use it
|
// If this mode has a saved config, use it
|
||||||
if (savedConfigId) {
|
if (savedConfigId) {
|
||||||
const config = listApiConfig?.find(c => c.id === savedConfigId)
|
const config = listApiConfig?.find((c) => c.id === savedConfigId)
|
||||||
if (config?.name) {
|
if (config?.name) {
|
||||||
const apiConfig = await this.configManager.LoadConfig(config.name)
|
const apiConfig = await this.configManager.LoadConfig(config.name)
|
||||||
await Promise.all([
|
await Promise.all([
|
||||||
this.updateGlobalState("currentApiConfigName", config.name),
|
this.updateGlobalState("currentApiConfigName", config.name),
|
||||||
this.updateApiConfiguration(apiConfig)
|
this.updateApiConfiguration(apiConfig),
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If no saved config for this mode, save current config as default
|
// If no saved config for this mode, save current config as default
|
||||||
const currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
const currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
||||||
if (currentApiConfigName) {
|
if (currentApiConfigName) {
|
||||||
const config = listApiConfig?.find(c => c.name === currentApiConfigName)
|
const config = listApiConfig?.find((c) => c.name === currentApiConfigName)
|
||||||
if (config?.id) {
|
if (config?.id) {
|
||||||
await this.configManager.SetModeConfig(newMode, config.id)
|
await this.configManager.SetModeConfig(newMode, config.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.postStateToWebview()
|
await this.postStateToWebview()
|
||||||
break
|
break
|
||||||
case "updateEnhancedPrompt":
|
case "updateEnhancedPrompt":
|
||||||
const existingPrompts = await this.getGlobalState("customPrompts") || {}
|
const existingPrompts = (await this.getGlobalState("customPrompts")) || {}
|
||||||
|
|
||||||
const updatedPrompts = {
|
const updatedPrompts = {
|
||||||
...existingPrompts,
|
...existingPrompts,
|
||||||
enhance: message.text
|
enhance: message.text,
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.updateGlobalState("customPrompts", updatedPrompts)
|
await this.updateGlobalState("customPrompts", updatedPrompts)
|
||||||
|
|
||||||
// Get current state and explicitly include customPrompts
|
// Get current state and explicitly include customPrompts
|
||||||
const currentState = await this.getState()
|
const currentState = await this.getState()
|
||||||
|
|
||||||
const stateWithPrompts = {
|
const stateWithPrompts = {
|
||||||
...currentState,
|
...currentState,
|
||||||
customPrompts: updatedPrompts
|
customPrompts: updatedPrompts,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post state with prompts
|
// Post state with prompts
|
||||||
this.view?.webview.postMessage({
|
this.view?.webview.postMessage({
|
||||||
type: "state",
|
type: "state",
|
||||||
state: stateWithPrompts
|
state: stateWithPrompts,
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
case "updatePrompt":
|
case "updatePrompt":
|
||||||
if (message.promptMode && message.customPrompt !== undefined) {
|
if (message.promptMode && message.customPrompt !== undefined) {
|
||||||
const existingPrompts = await this.getGlobalState("customPrompts") || {}
|
const existingPrompts = (await this.getGlobalState("customPrompts")) || {}
|
||||||
|
|
||||||
const updatedPrompts = {
|
const updatedPrompts = {
|
||||||
...existingPrompts,
|
...existingPrompts,
|
||||||
[message.promptMode]: message.customPrompt
|
[message.promptMode]: message.customPrompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.updateGlobalState("customPrompts", updatedPrompts)
|
await this.updateGlobalState("customPrompts", updatedPrompts)
|
||||||
|
|
||||||
// Get current state and explicitly include customPrompts
|
// Get current state and explicitly include customPrompts
|
||||||
const currentState = await this.getState()
|
const currentState = await this.getState()
|
||||||
|
|
||||||
const stateWithPrompts = {
|
const stateWithPrompts = {
|
||||||
...currentState,
|
...currentState,
|
||||||
customPrompts: updatedPrompts
|
customPrompts: updatedPrompts,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post state with prompts
|
// Post state with prompts
|
||||||
this.view?.webview.postMessage({
|
this.view?.webview.postMessage({
|
||||||
type: "state",
|
type: "state",
|
||||||
state: stateWithPrompts
|
state: stateWithPrompts,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -817,60 +812,79 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
"Just this message",
|
"Just this message",
|
||||||
"This and all subsequent messages",
|
"This and all subsequent messages",
|
||||||
)
|
)
|
||||||
if ((answer === "Just this message" || answer === "This and all subsequent messages") &&
|
if (
|
||||||
this.cline && typeof message.value === 'number' && message.value) {
|
(answer === "Just this message" || answer === "This and all subsequent messages") &&
|
||||||
const timeCutoff = message.value - 1000; // 1 second buffer before the message to delete
|
this.cline &&
|
||||||
const messageIndex = this.cline.clineMessages.findIndex(msg => msg.ts && msg.ts >= timeCutoff)
|
typeof message.value === "number" &&
|
||||||
const apiConversationHistoryIndex = this.cline.apiConversationHistory.findIndex(msg => msg.ts && msg.ts >= timeCutoff)
|
message.value
|
||||||
|
) {
|
||||||
|
const timeCutoff = message.value - 1000 // 1 second buffer before the message to delete
|
||||||
|
const messageIndex = this.cline.clineMessages.findIndex(
|
||||||
|
(msg) => msg.ts && msg.ts >= timeCutoff,
|
||||||
|
)
|
||||||
|
const apiConversationHistoryIndex = this.cline.apiConversationHistory.findIndex(
|
||||||
|
(msg) => msg.ts && msg.ts >= timeCutoff,
|
||||||
|
)
|
||||||
|
|
||||||
if (messageIndex !== -1) {
|
if (messageIndex !== -1) {
|
||||||
const { historyItem } = await this.getTaskWithId(this.cline.taskId)
|
const { historyItem } = await this.getTaskWithId(this.cline.taskId)
|
||||||
|
|
||||||
if (answer === "Just this message") {
|
if (answer === "Just this message") {
|
||||||
// Find the next user message first
|
// Find the next user message first
|
||||||
const nextUserMessage = this.cline.clineMessages
|
const nextUserMessage = this.cline.clineMessages
|
||||||
.slice(messageIndex + 1)
|
.slice(messageIndex + 1)
|
||||||
.find(msg => msg.type === "say" && msg.say === "user_feedback")
|
.find((msg) => msg.type === "say" && msg.say === "user_feedback")
|
||||||
|
|
||||||
// Handle UI messages
|
// Handle UI messages
|
||||||
if (nextUserMessage) {
|
if (nextUserMessage) {
|
||||||
// Find absolute index of next user message
|
// Find absolute index of next user message
|
||||||
const nextUserMessageIndex = this.cline.clineMessages.findIndex(msg => msg === nextUserMessage)
|
const nextUserMessageIndex = this.cline.clineMessages.findIndex(
|
||||||
|
(msg) => msg === nextUserMessage,
|
||||||
|
)
|
||||||
// Keep messages before current message and after next user message
|
// Keep messages before current message and after next user message
|
||||||
await this.cline.overwriteClineMessages([
|
await this.cline.overwriteClineMessages([
|
||||||
...this.cline.clineMessages.slice(0, messageIndex),
|
...this.cline.clineMessages.slice(0, messageIndex),
|
||||||
...this.cline.clineMessages.slice(nextUserMessageIndex)
|
...this.cline.clineMessages.slice(nextUserMessageIndex),
|
||||||
])
|
])
|
||||||
} else {
|
} else {
|
||||||
// If no next user message, keep only messages before current message
|
// If no next user message, keep only messages before current message
|
||||||
await this.cline.overwriteClineMessages(
|
await this.cline.overwriteClineMessages(
|
||||||
this.cline.clineMessages.slice(0, messageIndex)
|
this.cline.clineMessages.slice(0, messageIndex),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle API messages
|
// Handle API messages
|
||||||
if (apiConversationHistoryIndex !== -1) {
|
if (apiConversationHistoryIndex !== -1) {
|
||||||
if (nextUserMessage && nextUserMessage.ts) {
|
if (nextUserMessage && nextUserMessage.ts) {
|
||||||
// Keep messages before current API message and after next user message
|
// Keep messages before current API message and after next user message
|
||||||
await this.cline.overwriteApiConversationHistory([
|
await this.cline.overwriteApiConversationHistory([
|
||||||
...this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex),
|
...this.cline.apiConversationHistory.slice(
|
||||||
...this.cline.apiConversationHistory.filter(msg => msg.ts && msg.ts >= nextUserMessage.ts)
|
0,
|
||||||
|
apiConversationHistoryIndex,
|
||||||
|
),
|
||||||
|
...this.cline.apiConversationHistory.filter(
|
||||||
|
(msg) => msg.ts && msg.ts >= nextUserMessage.ts,
|
||||||
|
),
|
||||||
])
|
])
|
||||||
} else {
|
} else {
|
||||||
// If no next user message, keep only messages before current API message
|
// If no next user message, keep only messages before current API message
|
||||||
await this.cline.overwriteApiConversationHistory(
|
await this.cline.overwriteApiConversationHistory(
|
||||||
this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex)
|
this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (answer === "This and all subsequent messages") {
|
} else if (answer === "This and all subsequent messages") {
|
||||||
// Delete this message and all that follow
|
// Delete this message and all that follow
|
||||||
await this.cline.overwriteClineMessages(this.cline.clineMessages.slice(0, messageIndex))
|
await this.cline.overwriteClineMessages(
|
||||||
|
this.cline.clineMessages.slice(0, messageIndex),
|
||||||
|
)
|
||||||
if (apiConversationHistoryIndex !== -1) {
|
if (apiConversationHistoryIndex !== -1) {
|
||||||
await this.cline.overwriteApiConversationHistory(this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex))
|
await this.cline.overwriteApiConversationHistory(
|
||||||
|
this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.initClineWithHistoryItem(historyItem)
|
await this.initClineWithHistoryItem(historyItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -891,12 +905,13 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
case "enhancePrompt":
|
case "enhancePrompt":
|
||||||
if (message.text) {
|
if (message.text) {
|
||||||
try {
|
try {
|
||||||
const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } = await this.getState()
|
const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } =
|
||||||
|
await this.getState()
|
||||||
|
|
||||||
// Try to get enhancement config first, fall back to current config
|
// Try to get enhancement config first, fall back to current config
|
||||||
let configToUse: ApiConfiguration = apiConfiguration
|
let configToUse: ApiConfiguration = apiConfiguration
|
||||||
if (enhancementApiConfigId) {
|
if (enhancementApiConfigId) {
|
||||||
const config = listApiConfigMeta?.find(c => c.id === enhancementApiConfigId)
|
const config = listApiConfigMeta?.find((c) => c.id === enhancementApiConfigId)
|
||||||
if (config?.name) {
|
if (config?.name) {
|
||||||
const loadedConfig = await this.configManager.LoadConfig(config.name)
|
const loadedConfig = await this.configManager.LoadConfig(config.name)
|
||||||
if (loadedConfig.apiProvider) {
|
if (loadedConfig.apiProvider) {
|
||||||
@@ -904,41 +919,49 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const getEnhancePrompt = (value: string | PromptComponent | undefined): string => {
|
const getEnhancePrompt = (value: string | PromptComponent | undefined): string => {
|
||||||
if (typeof value === 'string') {
|
if (typeof value === "string") {
|
||||||
return value;
|
return value
|
||||||
}
|
}
|
||||||
return enhance.prompt; // Use the constant from modes.ts which we know is a string
|
return enhance.prompt // Use the constant from modes.ts which we know is a string
|
||||||
}
|
}
|
||||||
const enhancedPrompt = await enhancePrompt(
|
const enhancedPrompt = await enhancePrompt(
|
||||||
configToUse,
|
configToUse,
|
||||||
message.text,
|
message.text,
|
||||||
getEnhancePrompt(customPrompts?.enhance)
|
getEnhancePrompt(customPrompts?.enhance),
|
||||||
)
|
)
|
||||||
await this.postMessageToWebview({
|
await this.postMessageToWebview({
|
||||||
type: "enhancedPrompt",
|
type: "enhancedPrompt",
|
||||||
text: enhancedPrompt
|
text: enhancedPrompt,
|
||||||
})
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error enhancing prompt:", error)
|
console.error("Error enhancing prompt:", error)
|
||||||
vscode.window.showErrorMessage("Failed to enhance prompt")
|
vscode.window.showErrorMessage("Failed to enhance prompt")
|
||||||
await this.postMessageToWebview({
|
await this.postMessageToWebview({
|
||||||
type: "enhancedPrompt"
|
type: "enhancedPrompt",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case "getSystemPrompt":
|
case "getSystemPrompt":
|
||||||
try {
|
try {
|
||||||
const { apiConfiguration, customPrompts, customInstructions, preferredLanguage, browserViewportSize, mcpEnabled } = await this.getState()
|
const {
|
||||||
const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || ''
|
apiConfiguration,
|
||||||
|
customPrompts,
|
||||||
|
customInstructions,
|
||||||
|
preferredLanguage,
|
||||||
|
browserViewportSize,
|
||||||
|
mcpEnabled,
|
||||||
|
} = await this.getState()
|
||||||
|
const cwd =
|
||||||
|
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || ""
|
||||||
|
|
||||||
const mode = message.mode ?? defaultModeSlug
|
const mode = message.mode ?? defaultModeSlug
|
||||||
const instructions = await addCustomInstructions(
|
const instructions = await addCustomInstructions(
|
||||||
{ customInstructions, customPrompts, preferredLanguage },
|
{ customInstructions, customPrompts, preferredLanguage },
|
||||||
cwd,
|
cwd,
|
||||||
mode
|
mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
const systemPrompt = await SYSTEM_PROMPT(
|
const systemPrompt = await SYSTEM_PROMPT(
|
||||||
@@ -948,14 +971,14 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
undefined,
|
undefined,
|
||||||
browserViewportSize ?? "900x600",
|
browserViewportSize ?? "900x600",
|
||||||
mode,
|
mode,
|
||||||
customPrompts
|
customPrompts,
|
||||||
)
|
)
|
||||||
const fullPrompt = instructions ? `${systemPrompt}${instructions}` : systemPrompt
|
const fullPrompt = instructions ? `${systemPrompt}${instructions}` : systemPrompt
|
||||||
|
|
||||||
await this.postMessageToWebview({
|
await this.postMessageToWebview({
|
||||||
type: "systemPrompt",
|
type: "systemPrompt",
|
||||||
text: fullPrompt,
|
text: fullPrompt,
|
||||||
mode: message.mode
|
mode: message.mode,
|
||||||
})
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error getting system prompt:", error)
|
console.error("Error getting system prompt:", error)
|
||||||
@@ -969,7 +992,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
const commits = await searchCommits(message.query || "", cwd)
|
const commits = await searchCommits(message.query || "", cwd)
|
||||||
await this.postMessageToWebview({
|
await this.postMessageToWebview({
|
||||||
type: "commitSearchResults",
|
type: "commitSearchResults",
|
||||||
commits
|
commits,
|
||||||
})
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error searching commits:", error)
|
console.error("Error searching commits:", error)
|
||||||
@@ -981,9 +1004,9 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
case "upsertApiConfiguration":
|
case "upsertApiConfiguration":
|
||||||
if (message.text && message.apiConfiguration) {
|
if (message.text && message.apiConfiguration) {
|
||||||
try {
|
try {
|
||||||
await this.configManager.SaveConfig(message.text, message.apiConfiguration);
|
await this.configManager.SaveConfig(message.text, message.apiConfiguration)
|
||||||
let listApiConfig = await this.configManager.ListConfig();
|
let listApiConfig = await this.configManager.ListConfig()
|
||||||
|
|
||||||
await Promise.all([
|
await Promise.all([
|
||||||
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||||
this.updateApiConfiguration(message.apiConfiguration),
|
this.updateApiConfiguration(message.apiConfiguration),
|
||||||
@@ -1002,18 +1025,16 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
try {
|
try {
|
||||||
const { oldName, newName } = message.values
|
const { oldName, newName } = message.values
|
||||||
|
|
||||||
await this.configManager.SaveConfig(newName, message.apiConfiguration);
|
await this.configManager.SaveConfig(newName, message.apiConfiguration)
|
||||||
await this.configManager.DeleteConfig(oldName)
|
await this.configManager.DeleteConfig(oldName)
|
||||||
|
|
||||||
let listApiConfig = await this.configManager.ListConfig();
|
let listApiConfig = await this.configManager.ListConfig()
|
||||||
const config = listApiConfig?.find(c => c.name === newName);
|
const config = listApiConfig?.find((c) => c.name === newName)
|
||||||
|
|
||||||
// Update listApiConfigMeta first to ensure UI has latest data
|
|
||||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig);
|
|
||||||
|
|
||||||
await Promise.all([
|
// Update listApiConfigMeta first to ensure UI has latest data
|
||||||
this.updateGlobalState("currentApiConfigName", newName),
|
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||||
])
|
|
||||||
|
await Promise.all([this.updateGlobalState("currentApiConfigName", newName)])
|
||||||
|
|
||||||
await this.postStateToWebview()
|
await this.postStateToWebview()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -1025,9 +1046,9 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
case "loadApiConfiguration":
|
case "loadApiConfiguration":
|
||||||
if (message.text) {
|
if (message.text) {
|
||||||
try {
|
try {
|
||||||
const apiConfig = await this.configManager.LoadConfig(message.text);
|
const apiConfig = await this.configManager.LoadConfig(message.text)
|
||||||
const listApiConfig = await this.configManager.ListConfig();
|
const listApiConfig = await this.configManager.ListConfig()
|
||||||
|
|
||||||
await Promise.all([
|
await Promise.all([
|
||||||
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||||
this.updateGlobalState("currentApiConfigName", message.text),
|
this.updateGlobalState("currentApiConfigName", message.text),
|
||||||
@@ -1054,16 +1075,16 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await this.configManager.DeleteConfig(message.text);
|
await this.configManager.DeleteConfig(message.text)
|
||||||
const listApiConfig = await this.configManager.ListConfig();
|
const listApiConfig = await this.configManager.ListConfig()
|
||||||
|
|
||||||
// Update listApiConfigMeta first to ensure UI has latest data
|
// Update listApiConfigMeta first to ensure UI has latest data
|
||||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig);
|
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||||
|
|
||||||
// If this was the current config, switch to first available
|
// If this was the current config, switch to first available
|
||||||
let currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
let currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
||||||
if (message.text === currentApiConfigName && listApiConfig?.[0]?.name) {
|
if (message.text === currentApiConfigName && listApiConfig?.[0]?.name) {
|
||||||
const apiConfig = await this.configManager.LoadConfig(listApiConfig[0].name);
|
const apiConfig = await this.configManager.LoadConfig(listApiConfig[0].name)
|
||||||
await Promise.all([
|
await Promise.all([
|
||||||
this.updateGlobalState("currentApiConfigName", listApiConfig[0].name),
|
this.updateGlobalState("currentApiConfigName", listApiConfig[0].name),
|
||||||
this.updateApiConfiguration(apiConfig),
|
this.updateApiConfiguration(apiConfig),
|
||||||
@@ -1079,7 +1100,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
break
|
break
|
||||||
case "getListApiConfiguration":
|
case "getListApiConfiguration":
|
||||||
try {
|
try {
|
||||||
let listApiConfig = await this.configManager.ListConfig();
|
let listApiConfig = await this.configManager.ListConfig()
|
||||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||||
this.postMessageToWebview({ type: "listApiConfig", listApiConfig })
|
this.postMessageToWebview({ type: "listApiConfig", listApiConfig })
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -1087,7 +1108,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
vscode.window.showErrorMessage("Failed to get list api configuration")
|
vscode.window.showErrorMessage("Failed to get list api configuration")
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case "experimentalDiffStrategy":
|
case "experimentalDiffStrategy":
|
||||||
await this.updateGlobalState("experimentalDiffStrategy", message.bool ?? false)
|
await this.updateGlobalState("experimentalDiffStrategy", message.bool ?? false)
|
||||||
// Update diffStrategy in current Cline instance if it exists
|
// Update diffStrategy in current Cline instance if it exists
|
||||||
if (this.cline) {
|
if (this.cline) {
|
||||||
@@ -1103,13 +1124,13 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
|
|
||||||
private async updateApiConfiguration(apiConfiguration: ApiConfiguration) {
|
private async updateApiConfiguration(apiConfiguration: ApiConfiguration) {
|
||||||
// Update mode's default config
|
// Update mode's default config
|
||||||
const { mode } = await this.getState();
|
const { mode } = await this.getState()
|
||||||
if (mode) {
|
if (mode) {
|
||||||
const currentApiConfigName = await this.getGlobalState("currentApiConfigName");
|
const currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
||||||
const listApiConfig = await this.configManager.ListConfig();
|
const listApiConfig = await this.configManager.ListConfig()
|
||||||
const config = listApiConfig?.find(c => c.name === currentApiConfigName);
|
const config = listApiConfig?.find((c) => c.name === currentApiConfigName)
|
||||||
if (config?.id) {
|
if (config?.id) {
|
||||||
await this.configManager.SetModeConfig(mode, config.id);
|
await this.configManager.SetModeConfig(mode, config.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1181,7 +1202,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
await this.storeSecret("mistralApiKey", mistralApiKey)
|
await this.storeSecret("mistralApiKey", mistralApiKey)
|
||||||
if (this.cline) {
|
if (this.cline) {
|
||||||
this.cline.api = buildApiHandler(apiConfiguration)
|
this.cline.api = buildApiHandler(apiConfiguration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateCustomInstructions(instructions?: string) {
|
async updateCustomInstructions(instructions?: string) {
|
||||||
@@ -1252,11 +1273,11 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
// VSCode LM API
|
// VSCode LM API
|
||||||
private async getVsCodeLmModels() {
|
private async getVsCodeLmModels() {
|
||||||
try {
|
try {
|
||||||
const models = await vscode.lm.selectChatModels({});
|
const models = await vscode.lm.selectChatModels({})
|
||||||
return models || [];
|
return models || []
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error fetching VS Code LM models:', error);
|
console.error("Error fetching VS Code LM models:", error)
|
||||||
return [];
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1346,10 +1367,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async readGlamaModels(): Promise<Record<string, ModelInfo> | undefined> {
|
async readGlamaModels(): Promise<Record<string, ModelInfo> | undefined> {
|
||||||
const glamaModelsFilePath = path.join(
|
const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels)
|
||||||
await this.ensureCacheDirectoryExists(),
|
|
||||||
GlobalFileNames.glamaModels,
|
|
||||||
)
|
|
||||||
const fileExists = await fileExistsAtPath(glamaModelsFilePath)
|
const fileExists = await fileExistsAtPath(glamaModelsFilePath)
|
||||||
if (fileExists) {
|
if (fileExists) {
|
||||||
const fileContents = await fs.readFile(glamaModelsFilePath, "utf8")
|
const fileContents = await fs.readFile(glamaModelsFilePath, "utf8")
|
||||||
@@ -1359,10 +1377,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async refreshGlamaModels() {
|
async refreshGlamaModels() {
|
||||||
const glamaModelsFilePath = path.join(
|
const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels)
|
||||||
await this.ensureCacheDirectoryExists(),
|
|
||||||
GlobalFileNames.glamaModels,
|
|
||||||
)
|
|
||||||
|
|
||||||
let models: Record<string, ModelInfo> = {}
|
let models: Record<string, ModelInfo> = {}
|
||||||
try {
|
try {
|
||||||
@@ -1397,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
if (response.data) {
|
if (response.data) {
|
||||||
const rawModels = response.data;
|
const rawModels = response.data
|
||||||
const parsePrice = (price: any) => {
|
const parsePrice = (price: any) => {
|
||||||
if (price) {
|
if (price) {
|
||||||
return parseFloat(price) * 1_000_000
|
return parseFloat(price) * 1_000_000
|
||||||
@@ -1565,7 +1580,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
uiMessagesFilePath: string
|
uiMessagesFilePath: string
|
||||||
apiConversationHistory: Anthropic.MessageParam[]
|
apiConversationHistory: Anthropic.MessageParam[]
|
||||||
}> {
|
}> {
|
||||||
const history = (await this.getGlobalState("taskHistory") as HistoryItem[] | undefined) || []
|
const history = ((await this.getGlobalState("taskHistory")) as HistoryItem[] | undefined) || []
|
||||||
const historyItem = history.find((item) => item.id === id)
|
const historyItem = history.find((item) => item.id === id)
|
||||||
if (historyItem) {
|
if (historyItem) {
|
||||||
const taskDirPath = path.join(this.context.globalStorageUri.fsPath, "tasks", id)
|
const taskDirPath = path.join(this.context.globalStorageUri.fsPath, "tasks", id)
|
||||||
@@ -1630,7 +1645,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
|
|
||||||
async deleteTaskFromState(id: string) {
|
async deleteTaskFromState(id: string) {
|
||||||
// Remove the task from history
|
// Remove the task from history
|
||||||
const taskHistory = (await this.getGlobalState("taskHistory") as HistoryItem[]) || []
|
const taskHistory = ((await this.getGlobalState("taskHistory")) as HistoryItem[]) || []
|
||||||
const updatedTaskHistory = taskHistory.filter((task) => task.id !== id)
|
const updatedTaskHistory = taskHistory.filter((task) => task.id !== id)
|
||||||
await this.updateGlobalState("taskHistory", updatedTaskHistory)
|
await this.updateGlobalState("taskHistory", updatedTaskHistory)
|
||||||
|
|
||||||
@@ -1671,13 +1686,11 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
mode,
|
mode,
|
||||||
customPrompts,
|
customPrompts,
|
||||||
enhancementApiConfigId,
|
enhancementApiConfigId,
|
||||||
experimentalDiffStrategy,
|
experimentalDiffStrategy,
|
||||||
autoApprovalEnabled,
|
autoApprovalEnabled,
|
||||||
} = await this.getState()
|
} = await this.getState()
|
||||||
|
|
||||||
const allowedCommands = vscode.workspace
|
const allowedCommands = vscode.workspace.getConfiguration("roo-cline").get<string[]>("allowedCommands") || []
|
||||||
.getConfiguration('roo-cline')
|
|
||||||
.get<string[]>('allowedCommands') || []
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
version: this.context.extension?.packageJSON?.version ?? "",
|
version: this.context.extension?.packageJSON?.version ?? "",
|
||||||
@@ -1700,7 +1713,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
soundVolume: soundVolume ?? 0.5,
|
soundVolume: soundVolume ?? 0.5,
|
||||||
browserViewportSize: browserViewportSize ?? "900x600",
|
browserViewportSize: browserViewportSize ?? "900x600",
|
||||||
screenshotQuality: screenshotQuality ?? 75,
|
screenshotQuality: screenshotQuality ?? 75,
|
||||||
preferredLanguage: preferredLanguage ?? 'English',
|
preferredLanguage: preferredLanguage ?? "English",
|
||||||
writeDelayMs: writeDelayMs ?? 1000,
|
writeDelayMs: writeDelayMs ?? 1000,
|
||||||
terminalOutputLineLimit: terminalOutputLineLimit ?? 500,
|
terminalOutputLineLimit: terminalOutputLineLimit ?? 500,
|
||||||
fuzzyMatchThreshold: fuzzyMatchThreshold ?? 1.0,
|
fuzzyMatchThreshold: fuzzyMatchThreshold ?? 1.0,
|
||||||
@@ -1712,7 +1725,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
mode: mode ?? defaultModeSlug,
|
mode: mode ?? defaultModeSlug,
|
||||||
customPrompts: customPrompts ?? {},
|
customPrompts: customPrompts ?? {},
|
||||||
enhancementApiConfigId,
|
enhancementApiConfigId,
|
||||||
experimentalDiffStrategy: experimentalDiffStrategy ?? false,
|
experimentalDiffStrategy: experimentalDiffStrategy ?? false,
|
||||||
autoApprovalEnabled: autoApprovalEnabled ?? false,
|
autoApprovalEnabled: autoApprovalEnabled ?? false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1829,7 +1842,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
modeApiConfigs,
|
modeApiConfigs,
|
||||||
customPrompts,
|
customPrompts,
|
||||||
enhancementApiConfigId,
|
enhancementApiConfigId,
|
||||||
experimentalDiffStrategy,
|
experimentalDiffStrategy,
|
||||||
autoApprovalEnabled,
|
autoApprovalEnabled,
|
||||||
] = await Promise.all([
|
] = await Promise.all([
|
||||||
this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
|
this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
|
||||||
@@ -1891,7 +1904,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
this.getGlobalState("modeApiConfigs") as Promise<Record<Mode, string> | undefined>,
|
this.getGlobalState("modeApiConfigs") as Promise<Record<Mode, string> | undefined>,
|
||||||
this.getGlobalState("customPrompts") as Promise<CustomPrompts | undefined>,
|
this.getGlobalState("customPrompts") as Promise<CustomPrompts | undefined>,
|
||||||
this.getGlobalState("enhancementApiConfigId") as Promise<string | undefined>,
|
this.getGlobalState("enhancementApiConfigId") as Promise<string | undefined>,
|
||||||
this.getGlobalState("experimentalDiffStrategy") as Promise<boolean | undefined>,
|
this.getGlobalState("experimentalDiffStrategy") as Promise<boolean | undefined>,
|
||||||
this.getGlobalState("autoApprovalEnabled") as Promise<boolean | undefined>,
|
this.getGlobalState("autoApprovalEnabled") as Promise<boolean | undefined>,
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -1962,48 +1975,50 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
writeDelayMs: writeDelayMs ?? 1000,
|
writeDelayMs: writeDelayMs ?? 1000,
|
||||||
terminalOutputLineLimit: terminalOutputLineLimit ?? 500,
|
terminalOutputLineLimit: terminalOutputLineLimit ?? 500,
|
||||||
mode: mode ?? defaultModeSlug,
|
mode: mode ?? defaultModeSlug,
|
||||||
preferredLanguage: preferredLanguage ?? (() => {
|
preferredLanguage:
|
||||||
// Get VSCode's locale setting
|
preferredLanguage ??
|
||||||
const vscodeLang = vscode.env.language;
|
(() => {
|
||||||
// Map VSCode locale to our supported languages
|
// Get VSCode's locale setting
|
||||||
const langMap: { [key: string]: string } = {
|
const vscodeLang = vscode.env.language
|
||||||
'en': 'English',
|
// Map VSCode locale to our supported languages
|
||||||
'ar': 'Arabic',
|
const langMap: { [key: string]: string } = {
|
||||||
'pt-br': 'Brazilian Portuguese',
|
en: "English",
|
||||||
'cs': 'Czech',
|
ar: "Arabic",
|
||||||
'fr': 'French',
|
"pt-br": "Brazilian Portuguese",
|
||||||
'de': 'German',
|
cs: "Czech",
|
||||||
'hi': 'Hindi',
|
fr: "French",
|
||||||
'hu': 'Hungarian',
|
de: "German",
|
||||||
'it': 'Italian',
|
hi: "Hindi",
|
||||||
'ja': 'Japanese',
|
hu: "Hungarian",
|
||||||
'ko': 'Korean',
|
it: "Italian",
|
||||||
'pl': 'Polish',
|
ja: "Japanese",
|
||||||
'pt': 'Portuguese',
|
ko: "Korean",
|
||||||
'ru': 'Russian',
|
pl: "Polish",
|
||||||
'zh-cn': 'Simplified Chinese',
|
pt: "Portuguese",
|
||||||
'es': 'Spanish',
|
ru: "Russian",
|
||||||
'zh-tw': 'Traditional Chinese',
|
"zh-cn": "Simplified Chinese",
|
||||||
'tr': 'Turkish'
|
es: "Spanish",
|
||||||
};
|
"zh-tw": "Traditional Chinese",
|
||||||
// Return mapped language or default to English
|
tr: "Turkish",
|
||||||
return langMap[vscodeLang.split('-')[0]] ?? 'English';
|
}
|
||||||
})(),
|
// Return mapped language or default to English
|
||||||
|
return langMap[vscodeLang.split("-")[0]] ?? "English"
|
||||||
|
})(),
|
||||||
mcpEnabled: mcpEnabled ?? true,
|
mcpEnabled: mcpEnabled ?? true,
|
||||||
alwaysApproveResubmit: alwaysApproveResubmit ?? false,
|
alwaysApproveResubmit: alwaysApproveResubmit ?? false,
|
||||||
requestDelaySeconds: requestDelaySeconds ?? 5,
|
requestDelaySeconds: requestDelaySeconds ?? 5,
|
||||||
currentApiConfigName: currentApiConfigName ?? "default",
|
currentApiConfigName: currentApiConfigName ?? "default",
|
||||||
listApiConfigMeta: listApiConfigMeta ?? [],
|
listApiConfigMeta: listApiConfigMeta ?? [],
|
||||||
modeApiConfigs: modeApiConfigs ?? {} as Record<Mode, string>,
|
modeApiConfigs: modeApiConfigs ?? ({} as Record<Mode, string>),
|
||||||
customPrompts: customPrompts ?? {},
|
customPrompts: customPrompts ?? {},
|
||||||
enhancementApiConfigId,
|
enhancementApiConfigId,
|
||||||
experimentalDiffStrategy: experimentalDiffStrategy ?? false,
|
experimentalDiffStrategy: experimentalDiffStrategy ?? false,
|
||||||
autoApprovalEnabled: autoApprovalEnabled ?? false,
|
autoApprovalEnabled: autoApprovalEnabled ?? false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateTaskHistory(item: HistoryItem): Promise<HistoryItem[]> {
|
async updateTaskHistory(item: HistoryItem): Promise<HistoryItem[]> {
|
||||||
const history = (await this.getGlobalState("taskHistory") as HistoryItem[] | undefined) || []
|
const history = ((await this.getGlobalState("taskHistory")) as HistoryItem[] | undefined) || []
|
||||||
const existingItemIndex = history.findIndex((h) => h.id === item.id)
|
const existingItemIndex = history.findIndex((h) => h.id === item.id)
|
||||||
|
|
||||||
if (existingItemIndex !== -1) {
|
if (existingItemIndex !== -1) {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -27,13 +27,11 @@ export function activate(context: vscode.ExtensionContext) {
|
|||||||
outputChannel.appendLine("Cline extension activated")
|
outputChannel.appendLine("Cline extension activated")
|
||||||
|
|
||||||
// Get default commands from configuration
|
// Get default commands from configuration
|
||||||
const defaultCommands = vscode.workspace
|
const defaultCommands = vscode.workspace.getConfiguration("roo-cline").get<string[]>("allowedCommands") || []
|
||||||
.getConfiguration('roo-cline')
|
|
||||||
.get<string[]>('allowedCommands') || [];
|
|
||||||
|
|
||||||
// Initialize global state if not already set
|
// Initialize global state if not already set
|
||||||
if (!context.globalState.get('allowedCommands')) {
|
if (!context.globalState.get("allowedCommands")) {
|
||||||
context.globalState.update('allowedCommands', defaultCommands);
|
context.globalState.update("allowedCommands", defaultCommands)
|
||||||
}
|
}
|
||||||
|
|
||||||
const sidebarProvider = new ClineProvider(context, outputChannel)
|
const sidebarProvider = new ClineProvider(context, outputChannel)
|
||||||
|
|||||||
@@ -132,10 +132,10 @@ export class DiffViewProvider {
|
|||||||
// Apply the final content
|
// Apply the final content
|
||||||
const finalEdit = new vscode.WorkspaceEdit()
|
const finalEdit = new vscode.WorkspaceEdit()
|
||||||
finalEdit.replace(document.uri, new vscode.Range(0, 0, document.lineCount, 0), accumulatedContent)
|
finalEdit.replace(document.uri, new vscode.Range(0, 0, document.lineCount, 0), accumulatedContent)
|
||||||
await vscode.workspace.applyEdit(finalEdit)
|
await vscode.workspace.applyEdit(finalEdit)
|
||||||
// Clear all decorations at the end (after applying final edit)
|
// Clear all decorations at the end (after applying final edit)
|
||||||
this.fadedOverlayController.clear()
|
this.fadedOverlayController.clear()
|
||||||
this.activeLineController.clear()
|
this.activeLineController.clear()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -352,4 +352,4 @@ export class DiffViewProvider {
|
|||||||
this.streamedLines = []
|
this.streamedLines = []
|
||||||
this.preDiagnostics = []
|
this.preDiagnostics = []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import { DiffViewProvider } from '../DiffViewProvider';
|
import { DiffViewProvider } from "../DiffViewProvider"
|
||||||
import * as vscode from 'vscode';
|
import * as vscode from "vscode"
|
||||||
|
|
||||||
// Mock vscode
|
// Mock vscode
|
||||||
jest.mock('vscode', () => ({
|
jest.mock("vscode", () => ({
|
||||||
workspace: {
|
workspace: {
|
||||||
applyEdit: jest.fn(),
|
applyEdit: jest.fn(),
|
||||||
},
|
},
|
||||||
@@ -19,34 +19,34 @@ jest.mock('vscode', () => ({
|
|||||||
TextEditorRevealType: {
|
TextEditorRevealType: {
|
||||||
InCenter: 2,
|
InCenter: 2,
|
||||||
},
|
},
|
||||||
}));
|
}))
|
||||||
|
|
||||||
// Mock DecorationController
|
// Mock DecorationController
|
||||||
jest.mock('../DecorationController', () => ({
|
jest.mock("../DecorationController", () => ({
|
||||||
DecorationController: jest.fn().mockImplementation(() => ({
|
DecorationController: jest.fn().mockImplementation(() => ({
|
||||||
setActiveLine: jest.fn(),
|
setActiveLine: jest.fn(),
|
||||||
updateOverlayAfterLine: jest.fn(),
|
updateOverlayAfterLine: jest.fn(),
|
||||||
clear: jest.fn(),
|
clear: jest.fn(),
|
||||||
})),
|
})),
|
||||||
}));
|
}))
|
||||||
|
|
||||||
describe('DiffViewProvider', () => {
|
describe("DiffViewProvider", () => {
|
||||||
let diffViewProvider: DiffViewProvider;
|
let diffViewProvider: DiffViewProvider
|
||||||
const mockCwd = '/mock/cwd';
|
const mockCwd = "/mock/cwd"
|
||||||
let mockWorkspaceEdit: { replace: jest.Mock; delete: jest.Mock };
|
let mockWorkspaceEdit: { replace: jest.Mock; delete: jest.Mock }
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks()
|
||||||
mockWorkspaceEdit = {
|
mockWorkspaceEdit = {
|
||||||
replace: jest.fn(),
|
replace: jest.fn(),
|
||||||
delete: jest.fn(),
|
delete: jest.fn(),
|
||||||
};
|
}
|
||||||
(vscode.WorkspaceEdit as jest.Mock).mockImplementation(() => mockWorkspaceEdit);
|
;(vscode.WorkspaceEdit as jest.Mock).mockImplementation(() => mockWorkspaceEdit)
|
||||||
|
|
||||||
diffViewProvider = new DiffViewProvider(mockCwd);
|
diffViewProvider = new DiffViewProvider(mockCwd)
|
||||||
// Mock the necessary properties and methods
|
// Mock the necessary properties and methods
|
||||||
(diffViewProvider as any).relPath = 'test.txt';
|
;(diffViewProvider as any).relPath = "test.txt"
|
||||||
(diffViewProvider as any).activeDiffEditor = {
|
;(diffViewProvider as any).activeDiffEditor = {
|
||||||
document: {
|
document: {
|
||||||
uri: { fsPath: `${mockCwd}/test.txt` },
|
uri: { fsPath: `${mockCwd}/test.txt` },
|
||||||
getText: jest.fn(),
|
getText: jest.fn(),
|
||||||
@@ -58,43 +58,39 @@ describe('DiffViewProvider', () => {
|
|||||||
},
|
},
|
||||||
edit: jest.fn().mockResolvedValue(true),
|
edit: jest.fn().mockResolvedValue(true),
|
||||||
revealRange: jest.fn(),
|
revealRange: jest.fn(),
|
||||||
};
|
}
|
||||||
(diffViewProvider as any).activeLineController = { setActiveLine: jest.fn(), clear: jest.fn() };
|
;(diffViewProvider as any).activeLineController = { setActiveLine: jest.fn(), clear: jest.fn() }
|
||||||
(diffViewProvider as any).fadedOverlayController = { updateOverlayAfterLine: jest.fn(), clear: jest.fn() };
|
;(diffViewProvider as any).fadedOverlayController = { updateOverlayAfterLine: jest.fn(), clear: jest.fn() }
|
||||||
});
|
})
|
||||||
|
|
||||||
describe('update method', () => {
|
describe("update method", () => {
|
||||||
it('should preserve empty last line when original content has one', async () => {
|
it("should preserve empty last line when original content has one", async () => {
|
||||||
(diffViewProvider as any).originalContent = 'Original content\n';
|
;(diffViewProvider as any).originalContent = "Original content\n"
|
||||||
await diffViewProvider.update('New content', true);
|
await diffViewProvider.update("New content", true)
|
||||||
|
|
||||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
||||||
expect.anything(),
|
expect.anything(),
|
||||||
expect.anything(),
|
expect.anything(),
|
||||||
'New content\n'
|
"New content\n",
|
||||||
);
|
)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should not add extra newline when accumulated content already ends with one', async () => {
|
it("should not add extra newline when accumulated content already ends with one", async () => {
|
||||||
(diffViewProvider as any).originalContent = 'Original content\n';
|
;(diffViewProvider as any).originalContent = "Original content\n"
|
||||||
await diffViewProvider.update('New content\n', true);
|
await diffViewProvider.update("New content\n", true)
|
||||||
|
|
||||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
||||||
expect.anything(),
|
expect.anything(),
|
||||||
expect.anything(),
|
expect.anything(),
|
||||||
'New content\n'
|
"New content\n",
|
||||||
);
|
)
|
||||||
});
|
})
|
||||||
|
|
||||||
it('should not add newline when original content does not end with one', async () => {
|
it("should not add newline when original content does not end with one", async () => {
|
||||||
(diffViewProvider as any).originalContent = 'Original content';
|
;(diffViewProvider as any).originalContent = "Original content"
|
||||||
await diffViewProvider.update('New content', true);
|
await diffViewProvider.update("New content", true)
|
||||||
|
|
||||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(expect.anything(), expect.anything(), "New content")
|
||||||
expect.anything(),
|
})
|
||||||
expect.anything(),
|
})
|
||||||
'New content'
|
})
|
||||||
);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { detectCodeOmission } from '../detect-omission'
|
import { detectCodeOmission } from "../detect-omission"
|
||||||
|
|
||||||
describe('detectCodeOmission', () => {
|
describe("detectCodeOmission", () => {
|
||||||
const originalContent = `function example() {
|
const originalContent = `function example() {
|
||||||
// Some code
|
// Some code
|
||||||
const x = 1;
|
const x = 1;
|
||||||
@@ -10,124 +10,132 @@ describe('detectCodeOmission', () => {
|
|||||||
|
|
||||||
const generateLongContent = (commentLine: string, length: number = 90) => {
|
const generateLongContent = (commentLine: string, length: number = 90) => {
|
||||||
return `${commentLine}
|
return `${commentLine}
|
||||||
${Array.from({ length }, (_, i) => `const x${i} = ${i};`).join('\n')}
|
${Array.from({ length }, (_, i) => `const x${i} = ${i};`).join("\n")}
|
||||||
const y = 2;`
|
const y = 2;`
|
||||||
}
|
}
|
||||||
|
|
||||||
it('should skip comment checks for files under 100 lines', () => {
|
it("should skip comment checks for files under 100 lines", () => {
|
||||||
const newContent = `// Lines 1-50 remain unchanged
|
const newContent = `// Lines 1-50 remain unchanged
|
||||||
const z = 3;`
|
const z = 3;`
|
||||||
const predictedLineCount = 50
|
const predictedLineCount = 50
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not detect regular comments without omission keywords', () => {
|
it("should not detect regular comments without omission keywords", () => {
|
||||||
const newContent = generateLongContent('// Adding new functionality')
|
const newContent = generateLongContent("// Adding new functionality")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not detect when comment is part of original content', () => {
|
it("should not detect when comment is part of original content", () => {
|
||||||
const originalWithComment = `// Content remains unchanged
|
const originalWithComment = `// Content remains unchanged
|
||||||
${originalContent}`
|
${originalContent}`
|
||||||
const newContent = generateLongContent('// Content remains unchanged')
|
const newContent = generateLongContent("// Content remains unchanged")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalWithComment, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalWithComment, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not detect code that happens to contain omission keywords', () => {
|
it("should not detect code that happens to contain omission keywords", () => {
|
||||||
const newContent = generateLongContent(`const remains = 'some value';
|
const newContent = generateLongContent(`const remains = 'some value';
|
||||||
const unchanged = true;`)
|
const unchanged = true;`)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should detect suspicious single-line comment when content is more than 20% shorter', () => {
|
it("should detect suspicious single-line comment when content is more than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('// Previous content remains here\nconst x = 1;')
|
const newContent = generateLongContent("// Previous content remains here\nconst x = 1;")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag suspicious single-line comment when content is less than 20% shorter', () => {
|
it("should not flag suspicious single-line comment when content is less than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('// Previous content remains here', 130)
|
const newContent = generateLongContent("// Previous content remains here", 130)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should detect suspicious Python-style comment when content is more than 20% shorter', () => {
|
it("should detect suspicious Python-style comment when content is more than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('# Previous content remains here\nconst x = 1;')
|
const newContent = generateLongContent("# Previous content remains here\nconst x = 1;")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag suspicious Python-style comment when content is less than 20% shorter', () => {
|
it("should not flag suspicious Python-style comment when content is less than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('# Previous content remains here', 130)
|
const newContent = generateLongContent("# Previous content remains here", 130)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should detect suspicious multi-line comment when content is more than 20% shorter', () => {
|
it("should detect suspicious multi-line comment when content is more than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('/* Previous content remains the same */\nconst x = 1;')
|
const newContent = generateLongContent("/* Previous content remains the same */\nconst x = 1;")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag suspicious multi-line comment when content is less than 20% shorter', () => {
|
it("should not flag suspicious multi-line comment when content is less than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('/* Previous content remains the same */', 130)
|
const newContent = generateLongContent("/* Previous content remains the same */", 130)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should detect suspicious JSX comment when content is more than 20% shorter', () => {
|
it("should detect suspicious JSX comment when content is more than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('{/* Rest of the code remains the same */}\nconst x = 1;')
|
const newContent = generateLongContent("{/* Rest of the code remains the same */}\nconst x = 1;")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag suspicious JSX comment when content is less than 20% shorter', () => {
|
it("should not flag suspicious JSX comment when content is less than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('{/* Rest of the code remains the same */}', 130)
|
const newContent = generateLongContent("{/* Rest of the code remains the same */}", 130)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should detect suspicious HTML comment when content is more than 20% shorter', () => {
|
it("should detect suspicious HTML comment when content is more than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('<!-- Existing content unchanged -->\nconst x = 1;')
|
const newContent = generateLongContent("<!-- Existing content unchanged -->\nconst x = 1;")
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag suspicious HTML comment when content is less than 20% shorter', () => {
|
it("should not flag suspicious HTML comment when content is less than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('<!-- Existing content unchanged -->', 130)
|
const newContent = generateLongContent("<!-- Existing content unchanged -->", 130)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should detect suspicious square bracket notation when content is more than 20% shorter', () => {
|
it("should detect suspicious square bracket notation when content is more than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('[Previous content from line 1-305 remains exactly the same]\nconst x = 1;')
|
const newContent = generateLongContent(
|
||||||
|
"[Previous content from line 1-305 remains exactly the same]\nconst x = 1;",
|
||||||
|
)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag suspicious square bracket notation when content is less than 20% shorter', () => {
|
it("should not flag suspicious square bracket notation when content is less than 20% shorter", () => {
|
||||||
const newContent = generateLongContent('[Previous content from line 1-305 remains exactly the same]', 130)
|
const newContent = generateLongContent("[Previous content from line 1-305 remains exactly the same]", 130)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag content very close to predicted length', () => {
|
it("should not flag content very close to predicted length", () => {
|
||||||
const newContent = generateLongContent(`const x = 1;
|
const newContent = generateLongContent(
|
||||||
|
`const x = 1;
|
||||||
const y = 2;
|
const y = 2;
|
||||||
// This is a legitimate comment that remains here`, 130)
|
// This is a legitimate comment that remains here`,
|
||||||
|
130,
|
||||||
|
)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not flag when content is longer than predicted', () => {
|
it("should not flag when content is longer than predicted", () => {
|
||||||
const newContent = generateLongContent(`const x = 1;
|
const newContent = generateLongContent(
|
||||||
|
`const x = 1;
|
||||||
const y = 2;
|
const y = 2;
|
||||||
// Previous content remains here but we added more
|
// Previous content remains here but we added more
|
||||||
const z = 3;
|
const z = 3;
|
||||||
const w = 4;`, 160)
|
const w = 4;`,
|
||||||
|
160,
|
||||||
|
)
|
||||||
const predictedLineCount = 150
|
const predictedLineCount = 150
|
||||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||||
})
|
})
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user