From 60a0a824b96a0b326af4d8871b6903f4ddcfe114 Mon Sep 17 00:00:00 2001 From: Matt Rubens Date: Fri, 17 Jan 2025 14:11:28 -0500 Subject: [PATCH] Prettier backfill --- .changeset/changelog-config.js | 24 +- .changeset/config.json | 18 +- .github/pull_request_template.md | 11 +- README.md | 43 +- cline_docs/settings.md | 215 +- jest.config.js | 75 +- .../@modelcontextprotocol/sdk/client/index.js | 22 +- .../@modelcontextprotocol/sdk/client/stdio.js | 30 +- .../@modelcontextprotocol/sdk/index.js | 42 +- .../@modelcontextprotocol/sdk/types.js | 58 +- src/__mocks__/McpHub.ts | 26 +- src/__mocks__/default-shell.js | 14 +- src/__mocks__/delay.js | 6 +- src/__mocks__/globby.js | 12 +- src/__mocks__/os-name.js | 6 +- src/__mocks__/p-wait-for.js | 32 +- src/__mocks__/serialize-error.js | 36 +- src/__mocks__/strip-ansi.js | 8 +- src/__mocks__/vscode.js | 110 +- src/api/providers/__tests__/anthropic.test.ts | 439 ++-- src/api/providers/__tests__/bedrock.test.ts | 437 ++-- src/api/providers/__tests__/deepseek.test.ts | 384 ++-- src/api/providers/__tests__/gemini.test.ts | 368 ++- src/api/providers/__tests__/glama.test.ts | 420 ++-- src/api/providers/__tests__/lmstudio.test.ts | 295 +-- src/api/providers/__tests__/ollama.test.ts | 293 +-- .../providers/__tests__/openai-native.test.ts | 575 ++--- src/api/providers/__tests__/openai.test.ts | 409 ++-- .../providers/__tests__/openrouter.test.ts | 508 ++--- src/api/providers/__tests__/vertex.test.ts | 531 +++-- src/api/providers/__tests__/vscode-lm.test.ts | 372 ++-- src/api/providers/anthropic.ts | 6 +- src/api/providers/bedrock.ts | 504 ++--- src/api/providers/deepseek.ts | 38 +- src/api/providers/glama.ts | 39 +- src/api/providers/lmstudio.ts | 2 +- src/api/providers/ollama.ts | 2 +- src/api/providers/openai-native.ts | 9 +- src/api/providers/openai.ts | 10 +- src/api/providers/openrouter.ts | 26 +- src/api/providers/vertex.ts | 6 +- src/api/providers/vscode-lm.ts | 459 ++-- .../__tests__/bedrock-converse-format.test.ts | 438 ++-- .../transform/__tests__/openai-format.test.ts | 502 +++-- src/api/transform/__tests__/stream.test.ts | 192 +- .../__tests__/vscode-lm-format.test.ts | 358 +-- src/api/transform/bedrock-converse-format.ts | 370 ++-- src/api/transform/vscode-lm-format.ts | 185 +- src/core/Cline.ts | 190 +- src/core/__tests__/Cline.test.ts | 1496 ++++++------- src/core/__tests__/mode-validator.test.ts | 90 +- src/core/config/ConfigManager.ts | 392 ++-- .../config/__tests__/ConfigManager.test.ts | 906 ++++---- src/core/diff/DiffStrategy.ts | 22 +- .../strategies/__tests__/new-unified.test.ts | 447 ++-- .../__tests__/search-replace.test.ts | 1212 +++++----- .../diff/strategies/__tests__/unified.test.ts | 141 +- .../__tests__/edit-strategies.test.ts | 8 +- .../__tests__/search-strategies.test.ts | 368 +-- .../strategies/new-unified/edit-strategies.ts | 62 +- src/core/diff/strategies/new-unified/index.ts | 10 +- .../new-unified/search-strategies.ts | 40 +- src/core/diff/strategies/new-unified/types.ts | 26 +- src/core/diff/strategies/search-replace.ts | 457 ++-- src/core/diff/strategies/unified.ts | 58 +- src/core/diff/types.ts | 54 +- src/core/mentions/__tests__/index.test.ts | 74 +- src/core/mode-validator.ts | 16 +- src/core/prompts/__tests__/system.test.ts | 677 +++--- src/core/prompts/sections/capabilities.ts | 30 +- .../prompts/sections/custom-instructions.ts | 75 +- src/core/prompts/sections/index.ts | 16 +- src/core/prompts/sections/mcp-servers.ts | 79 +- src/core/prompts/sections/objective.ts | 4 +- src/core/prompts/sections/rules.ts | 26 +- src/core/prompts/sections/system-info.ts | 4 +- .../prompts/sections/tool-use-guidelines.ts | 4 +- src/core/prompts/sections/tool-use.ts | 4 +- src/core/prompts/system.ts | 197 +- src/core/prompts/tools/access-mcp-resource.ts | 12 +- .../prompts/tools/ask-followup-question.ts | 4 +- src/core/prompts/tools/attempt-completion.ts | 4 +- src/core/prompts/tools/browser-action.ts | 12 +- src/core/prompts/tools/execute-command.ts | 6 +- src/core/prompts/tools/index.ts | 133 +- .../tools/list-code-definition-names.ts | 6 +- src/core/prompts/tools/list-files.ts | 6 +- src/core/prompts/tools/read-file.ts | 6 +- src/core/prompts/tools/search-files.ts | 4 +- src/core/prompts/tools/types.ts | 18 +- src/core/prompts/tools/use-mcp-tool.ts | 12 +- src/core/prompts/tools/write-to-file.ts | 4 +- src/core/prompts/types.ts | 86 +- src/core/tool-lists.ts | 50 +- src/core/webview/ClineProvider.ts | 429 ++-- .../webview/__tests__/ClineProvider.test.ts | 1910 ++++++++-------- src/extension.ts | 8 +- src/integrations/editor/DiffViewProvider.ts | 10 +- .../editor/__tests__/DiffViewProvider.test.ts | 84 +- .../editor/__tests__/detect-omission.test.ts | 86 +- src/integrations/editor/detect-omission.ts | 18 +- .../misc/__tests__/extract-text.test.ts | 237 +- src/integrations/misc/extract-text.ts | 25 +- src/integrations/misc/open-file.ts | 18 +- src/integrations/terminal/TerminalManager.ts | 4 +- src/integrations/terminal/TerminalRegistry.ts | 4 +- .../__tests__/TerminalProcess.test.ts | 364 +-- .../__tests__/TerminalRegistry.test.ts | 52 +- .../workspace/WorkspaceTracker.ts | 6 +- .../__tests__/WorkspaceTracker.test.ts | 232 +- src/services/browser/BrowserSession.ts | 6 +- src/services/mcp/McpHub.ts | 51 +- src/services/mcp/__tests__/McpHub.test.ts | 506 ++--- .../tree-sitter/__tests__/index.test.ts | 444 ++-- .../__tests__/languageParser.test.ts | 216 +- src/shared/ExtensionMessage.ts | 2 +- src/shared/WebviewMessage.ts | 4 +- .../__tests__/checkExistApiConfig.test.ts | 102 +- .../__tests__/vsCodeSelectorUtils.test.ts | 66 +- src/shared/api.ts | 69 +- src/shared/checkExistApiConfig.ts | 32 +- src/shared/context-mentions.ts | 21 +- src/shared/modes.ts | 278 +-- src/shared/vsCodeSelectorUtils.ts | 13 +- src/test/extension.test.ts | 358 ++- src/test/tsconfig.json | 36 +- src/utils/__tests__/cost.test.ts | 172 +- src/utils/__tests__/enhance-prompt.test.ts | 220 +- src/utils/__tests__/git.test.ts | 578 ++--- src/utils/__tests__/path.test.ts | 224 +- src/utils/enhance-prompt.ts | 42 +- src/utils/git.ts | 58 +- src/utils/sound.ts | 2 +- webview-ui/config-overrides.js | 38 +- .../src/components/chat/Announcement.tsx | 21 +- .../src/components/chat/AutoApproveMenu.tsx | 45 +- .../src/components/chat/BrowserSessionRow.tsx | 11 +- webview-ui/src/components/chat/ChatRow.tsx | 55 +- .../src/components/chat/ChatTextArea.tsx | 173 +- webview-ui/src/components/chat/ChatView.tsx | 121 +- .../src/components/chat/ContextMenu.tsx | 53 +- .../chat/__tests__/AutoApproveMenu.test.tsx | 354 +-- .../chat/__tests__/ChatTextArea.test.tsx | 264 +-- .../__tests__/ChatView.auto-approve.test.tsx | 541 ++--- .../chat/__tests__/ChatView.test.tsx | 1972 +++++++++-------- .../src/components/common/CaretIcon.tsx | 27 +- .../components/common/__mocks__/CodeBlock.tsx | 12 +- .../common/__mocks__/MarkdownBlock.tsx | 12 +- .../src/components/history/HistoryView.tsx | 62 +- .../history/__tests__/HistoryView.test.tsx | 387 ++-- .../src/components/mcp/McpEnabledToggle.tsx | 22 +- webview-ui/src/components/mcp/McpToolRow.tsx | 13 +- webview-ui/src/components/mcp/McpView.tsx | 43 +- .../mcp/__tests__/McpToolRow.test.tsx | 238 +- .../src/components/prompts/PromptsView.tsx | 386 ++-- .../prompts/__tests__/PromptsView.test.tsx | 288 +-- .../components/settings/ApiConfigManager.tsx | 397 ++-- .../src/components/settings/ApiOptions.tsx | 112 +- .../components/settings/GlamaModelPicker.tsx | 20 +- .../components/settings/OpenAiModelPicker.tsx | 34 +- .../settings/OpenRouterModelPicker.tsx | 19 +- .../src/components/settings/SettingsView.tsx | 357 +-- .../__tests__/ApiConfigManager.test.tsx | 238 +- .../settings/__tests__/SettingsView.test.tsx | 578 ++--- .../src/components/welcome/WelcomeView.tsx | 8 +- .../src/context/ExtensionStateContext.tsx | 193 +- .../__tests__/ExtensionStateContext.test.tsx | 106 +- webview-ui/src/services/GitService.ts | 18 +- webview-ui/src/setupTests.ts | 48 +- .../__tests__/command-validation.test.ts | 123 +- .../utils/__tests__/context-mentions.test.ts | 131 +- webview-ui/src/utils/command-validation.ts | 53 +- webview-ui/src/utils/context-mentions.ts | 48 +- webview-ui/src/utils/highlight.ts | 22 +- 174 files changed, 15715 insertions(+), 15428 deletions(-) diff --git a/.changeset/changelog-config.js b/.changeset/changelog-config.js index 125e11c..1e64dbf 100644 --- a/.changeset/changelog-config.js +++ b/.changeset/changelog-config.js @@ -1,20 +1,20 @@ // Half-works to simplify the format but needs 'overwrite_changeset_changelog.py' in GHA to finish formatting const getReleaseLine = async (changeset) => { - const [firstLine] = changeset.summary - .split('\n') - .map(l => l.trim()) - .filter(Boolean); - return `- ${firstLine}`; -}; + const [firstLine] = changeset.summary + .split("\n") + .map((l) => l.trim()) + .filter(Boolean) + return `- ${firstLine}` +} const getDependencyReleaseLine = async () => { - return ''; -}; + return "" +} const changelogFunctions = { - getReleaseLine, - getDependencyReleaseLine, -}; + getReleaseLine, + getDependencyReleaseLine, +} -module.exports = changelogFunctions; \ No newline at end of file +module.exports = changelogFunctions diff --git a/.changeset/config.json b/.changeset/config.json index afe84d3..bcd6eef 100644 --- a/.changeset/config.json +++ b/.changeset/config.json @@ -1,11 +1,11 @@ { - "$schema": "https://unpkg.com/@changesets/config@3.0.4/schema.json", - "changelog": "./changelog-config.js", - "commit": false, - "fixed": [], - "linked": [], - "access": "restricted", - "baseBranch": "main", - "updateInternalDependencies": "patch", - "ignore": [] + "$schema": "https://unpkg.com/@changesets/config@3.0.4/schema.json", + "changelog": "./changelog-config.js", + "commit": false, + "fixed": [], + "linked": [], + "access": "restricted", + "baseBranch": "main", + "updateInternalDependencies": "patch", + "ignore": [] } diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a689eb9..7ee8bb9 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,28 +1,37 @@ + ## Description ## Type of change + + - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update ## How Has This Been Tested? + ## Checklist: + + - [ ] My code follows the patterns of this project - [ ] I have performed a self-review of my own code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have made corresponding changes to the documentation ## Additional context + ## Related Issues + ## Reviewers - \ No newline at end of file + + diff --git a/README.md b/README.md index f5856c7..2918f27 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,15 @@ Hot off the heels of **v3.0** introducing Code, Architect, and Ask chat modes, o You can now tailor the **role definition** and **custom instructions** for every chat mode to perfectly fit your workflow. Want to adjust Architect mode to focus more on system scalability? Or tweak Ask mode for deeper research queries? Done. Plus, you can define these via **mode-specific `.clinerules-[mode]` files**. You’ll find all of this in the new **Prompts** tab in the top menu. -The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Here’s what’s new: -- Works with **any provider** and API configuration, not just OpenRouter. -- Fully customizable prompts to match your unique needs. +The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Here’s what’s new: + +- Works with **any provider** and API configuration, not just OpenRouter. +- Fully customizable prompts to match your unique needs. - Same simple workflow: just hit the ✨ **Enhance Prompt** button in the chat input to try it out. Whether you’re using GPT-4, other APIs, or switching configurations, this gives you total control over how your prompts are optimized. -As always, we’d love to hear your thoughts and ideas! What features do you want to see in **v3.2**? Drop by https://www.reddit.com/r/roocline and join the discussion - we're building Roo Cline together. 🚀 +As always, we’d love to hear your thoughts and ideas! What features do you want to see in **v3.2**? Drop by https://www.reddit.com/r/roocline and join the discussion - we're building Roo Cline together. 🚀 ## New in 3.0 - Chat Modes! @@ -33,6 +34,7 @@ You can now choose between different prompts for Roo Cline to better suit your w It’s super simple! There’s a dropdown in the bottom left of the chat input to switch modes. Right next to it, you’ll find a way to switch between the API configuration profiles associated with the current mode (configured on the settings screen). **Why Add This?** + - It keeps Cline from being overly eager to jump into solving problems when you just want to think or ask questions. - Each mode remembers the API configuration you last used with it. For example, you can use more thoughtful models like OpenAI o1 for Architect and Ask, while sticking with Sonnet or DeepSeek for coding tasks. - It builds on research suggesting better results when separating "thinking" from "coding," explained well in this very thoughtful [article](https://aider.chat/2024/09/26/architect.html) from aider. @@ -50,25 +52,27 @@ Here's an example of Roo-Cline autonomously creating a snake game with "Always a https://github.com/user-attachments/assets/c2bb31dc-e9b2-4d73-885d-17f1471a4987 ## Contributing + To contribute to the project, start by exploring [open issues](https://github.com/RooVetGit/Roo-Cline/issues) or checking our [feature request board](https://github.com/RooVetGit/Roo-Cline/discussions/categories/feature-requests). We'd also love to have you join the [Roo Cline Reddit](https://www.reddit.com/r/roocline/) to share ideas and connect with other contributors. ### Local Setup 1. Install dependencies: - ```bash - npm run install:all - ``` + + ```bash + npm run install:all + ``` 2. Build the VSIX file: - ```bash - npm run build - ``` + ```bash + npm run build + ``` 3. The new VSIX file will be created in the `bin/` directory 4. Install the extension from the VSIX file as described below: - - **Option 1:** Drag and drop the `.vsix` file into your VSCode-compatible editor's Extensions panel (Cmd/Ctrl+Shift+X). + - **Option 1:** Drag and drop the `.vsix` file into your VSCode-compatible editor's Extensions panel (Cmd/Ctrl+Shift+X). - - **Option 2:** Install the plugin using the CLI, make sure you have your VSCode-compatible CLI installed and in your `PATH` variable. Cursor example: `export PATH="$PATH:/Applications/Cursor.app/Contents/MacOS"` + - **Option 2:** Install the plugin using the CLI, make sure you have your VSCode-compatible CLI installed and in your `PATH` variable. Cursor example: `export PATH="$PATH:/Applications/Cursor.app/Contents/MacOS"` ```bash # Ex: cursor --install-extension bin/roo-cline-2.0.1.vsix @@ -83,16 +87,17 @@ We use [changesets](https://github.com/changesets/changesets) for versioning and 1. Create a PR with your changes 2. Create a new changeset by running `npm run changeset` - - Select the appropriate kind of change - `patch` for bug fixes, `minor` for new features, or `major` for breaking changes - - Write a clear description of your changes that will be included in the changelog + - Select the appropriate kind of change - `patch` for bug fixes, `minor` for new features, or `major` for breaking changes + - Write a clear description of your changes that will be included in the changelog 3. Get the PR approved and pass all checks 4. Merge it Once your merge is successful: + - The release workflow will automatically create a new "Changeset version bump" PR - This PR will: - - Update the version based on your changeset - - Update the `CHANGELOG.md` file + - Update the version based on your changeset + - Update the `CHANGELOG.md` file - Once the PR is approved and merged, a new version will be published --- @@ -193,9 +198,9 @@ Try asking Cline to "test the app", and watch as he runs a command like `npm run Thanks to the [Model Context Protocol](https://github.com/modelcontextprotocol), Cline can extend his capabilities through custom tools. While you can use [community-made servers](https://github.com/modelcontextprotocol/servers), Cline can instead create and install tools tailored to your specific workflow. Just ask Cline to "add a tool" and he will handle everything, from creating a new MCP server to installing it into the extension. These custom tools then become part of Cline's toolkit, ready to use in future tasks. -- "add a tool that fetches Jira tickets": Retrieve ticket ACs and put Cline to work -- "add a tool that manages AWS EC2s": Check server metrics and scale instances up or down -- "add a tool that pulls the latest PagerDuty incidents": Fetch details and ask Cline to fix bugs +- "add a tool that fetches Jira tickets": Retrieve ticket ACs and put Cline to work +- "add a tool that manages AWS EC2s": Check server metrics and scale instances up or down +- "add a tool that pulls the latest PagerDuty incidents": Fetch details and ask Cline to fix bugs diff --git a/cline_docs/settings.md b/cline_docs/settings.md index c9c4702..f4b0682 100644 --- a/cline_docs/settings.md +++ b/cline_docs/settings.md @@ -1,137 +1,146 @@ - ## For All Settings 1. Add the setting to ExtensionMessage.ts: - - Add the setting to the ExtensionState interface - - Make it required if it has a default value, optional if it can be undefined - - Example: `preferredLanguage: string` + + - Add the setting to the ExtensionState interface + - Make it required if it has a default value, optional if it can be undefined + - Example: `preferredLanguage: string` 2. Add test coverage: - - Add the setting to mockState in ClineProvider.test.ts - - Add test cases for setting persistence and state updates - - Ensure all tests pass before submitting changes + - Add the setting to mockState in ClineProvider.test.ts + - Add test cases for setting persistence and state updates + - Ensure all tests pass before submitting changes ## For Checkbox Settings 1. Add the message type to WebviewMessage.ts: - - Add the setting name to the WebviewMessage type's type union - - Example: `| "multisearchDiffEnabled"` + + - Add the setting name to the WebviewMessage type's type union + - Example: `| "multisearchDiffEnabled"` 2. Add the setting to ExtensionStateContext.tsx: - - Add the setting to the ExtensionStateContextType interface - - Add the setter function to the interface - - Add the setting to the initial state in useState - - Add the setting to the contextValue object - - Example: - ```typescript - interface ExtensionStateContextType { - multisearchDiffEnabled: boolean; - setMultisearchDiffEnabled: (value: boolean) => void; - } - ``` + + - Add the setting to the ExtensionStateContextType interface + - Add the setter function to the interface + - Add the setting to the initial state in useState + - Add the setting to the contextValue object + - Example: + ```typescript + interface ExtensionStateContextType { + multisearchDiffEnabled: boolean + setMultisearchDiffEnabled: (value: boolean) => void + } + ``` 3. Add the setting to ClineProvider.ts: - - Add the setting name to the GlobalStateKey type union - - Add the setting to the Promise.all array in getState - - Add the setting to the return value in getState with a default value - - Add the setting to the destructured variables in getStateToPostToWebview - - Add the setting to the return value in getStateToPostToWebview - - Add a case in setWebviewMessageListener to handle the setting's message type - - Example: - ```typescript - case "multisearchDiffEnabled": - await this.updateGlobalState("multisearchDiffEnabled", message.bool) - await this.postStateToWebview() - break - ``` + + - Add the setting name to the GlobalStateKey type union + - Add the setting to the Promise.all array in getState + - Add the setting to the return value in getState with a default value + - Add the setting to the destructured variables in getStateToPostToWebview + - Add the setting to the return value in getStateToPostToWebview + - Add a case in setWebviewMessageListener to handle the setting's message type + - Example: + ```typescript + case "multisearchDiffEnabled": + await this.updateGlobalState("multisearchDiffEnabled", message.bool) + await this.postStateToWebview() + break + ``` 4. Add the checkbox UI to SettingsView.tsx: - - Import the setting and its setter from ExtensionStateContext - - Add the VSCodeCheckbox component with the setting's state and onChange handler - - Add appropriate labels and description text - - Example: - ```typescript - setMultisearchDiffEnabled(e.target.checked)} - > - Enable multi-search diff matching - - ``` + + - Import the setting and its setter from ExtensionStateContext + - Add the VSCodeCheckbox component with the setting's state and onChange handler + - Add appropriate labels and description text + - Example: + ```typescript + setMultisearchDiffEnabled(e.target.checked)} + > + Enable multi-search diff matching + + ``` 5. Add the setting to handleSubmit in SettingsView.tsx: - - Add a vscode.postMessage call to send the setting's value when clicking Done - - Example: - ```typescript - vscode.postMessage({ type: "multisearchDiffEnabled", bool: multisearchDiffEnabled }) - ``` + - Add a vscode.postMessage call to send the setting's value when clicking Done + - Example: + ```typescript + vscode.postMessage({ type: "multisearchDiffEnabled", bool: multisearchDiffEnabled }) + ``` ## For Select/Dropdown Settings 1. Add the message type to WebviewMessage.ts: - - Add the setting name to the WebviewMessage type's type union - - Example: `| "preferredLanguage"` + + - Add the setting name to the WebviewMessage type's type union + - Example: `| "preferredLanguage"` 2. Add the setting to ExtensionStateContext.tsx: - - Add the setting to the ExtensionStateContextType interface - - Add the setter function to the interface - - Add the setting to the initial state in useState with a default value - - Add the setting to the contextValue object - - Example: - ```typescript - interface ExtensionStateContextType { - preferredLanguage: string; - setPreferredLanguage: (value: string) => void; - } - ``` + + - Add the setting to the ExtensionStateContextType interface + - Add the setter function to the interface + - Add the setting to the initial state in useState with a default value + - Add the setting to the contextValue object + - Example: + ```typescript + interface ExtensionStateContextType { + preferredLanguage: string + setPreferredLanguage: (value: string) => void + } + ``` 3. Add the setting to ClineProvider.ts: - - Add the setting name to the GlobalStateKey type union - - Add the setting to the Promise.all array in getState - - Add the setting to the return value in getState with a default value - - Add the setting to the destructured variables in getStateToPostToWebview - - Add the setting to the return value in getStateToPostToWebview - - Add a case in setWebviewMessageListener to handle the setting's message type - - Example: - ```typescript - case "preferredLanguage": - await this.updateGlobalState("preferredLanguage", message.text) - await this.postStateToWebview() - break - ``` + + - Add the setting name to the GlobalStateKey type union + - Add the setting to the Promise.all array in getState + - Add the setting to the return value in getState with a default value + - Add the setting to the destructured variables in getStateToPostToWebview + - Add the setting to the return value in getStateToPostToWebview + - Add a case in setWebviewMessageListener to handle the setting's message type + - Example: + ```typescript + case "preferredLanguage": + await this.updateGlobalState("preferredLanguage", message.text) + await this.postStateToWebview() + break + ``` 4. Add the select UI to SettingsView.tsx: - - Import the setting and its setter from ExtensionStateContext - - Add the select element with appropriate styling to match VSCode's theme - - Add options for the dropdown - - Add appropriate labels and description text - - Example: - ```typescript - - ``` + + - Import the setting and its setter from ExtensionStateContext + - Add the select element with appropriate styling to match VSCode's theme + - Add options for the dropdown + - Add appropriate labels and description text + - Example: + ```typescript + + ``` 5. Add the setting to handleSubmit in SettingsView.tsx: - - Add a vscode.postMessage call to send the setting's value when clicking Done - - Example: - ```typescript - vscode.postMessage({ type: "preferredLanguage", text: preferredLanguage }) - ``` + - Add a vscode.postMessage call to send the setting's value when clicking Done + - Example: + ```typescript + vscode.postMessage({ type: "preferredLanguage", text: preferredLanguage }) + ``` These steps ensure that: + - The setting's state is properly typed throughout the application - The setting persists between sessions - The setting's value is properly synchronized between the webview and extension diff --git a/jest.config.js b/jest.config.js index a912315..02a4a78 100644 --- a/jest.config.js +++ b/jest.config.js @@ -1,41 +1,40 @@ /** @type {import('ts-jest').JestConfigWithTsJest} */ module.exports = { - preset: 'ts-jest', - testEnvironment: 'node', - moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], - transform: { - '^.+\\.tsx?$': ['ts-jest', { - tsconfig: { - "module": "CommonJS", - "moduleResolution": "node", - "esModuleInterop": true, - "allowJs": true - }, - diagnostics: false, - isolatedModules: true - }] - }, - testMatch: ['**/__tests__/**/*.test.ts'], - moduleNameMapper: { - '^vscode$': '/src/__mocks__/vscode.js', - '@modelcontextprotocol/sdk$': '/src/__mocks__/@modelcontextprotocol/sdk/index.js', - '@modelcontextprotocol/sdk/(.*)': '/src/__mocks__/@modelcontextprotocol/sdk/$1', - '^delay$': '/src/__mocks__/delay.js', - '^p-wait-for$': '/src/__mocks__/p-wait-for.js', - '^globby$': '/src/__mocks__/globby.js', - '^serialize-error$': '/src/__mocks__/serialize-error.js', - '^strip-ansi$': '/src/__mocks__/strip-ansi.js', - '^default-shell$': '/src/__mocks__/default-shell.js', - '^os-name$': '/src/__mocks__/os-name.js' - }, - transformIgnorePatterns: [ - 'node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)' - ], - modulePathIgnorePatterns: [ - '.vscode-test' - ], - reporters: [ - ["jest-simple-dot-reporter", {}] - ], - setupFiles: [] + preset: "ts-jest", + testEnvironment: "node", + moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"], + transform: { + "^.+\\.tsx?$": [ + "ts-jest", + { + tsconfig: { + module: "CommonJS", + moduleResolution: "node", + esModuleInterop: true, + allowJs: true, + }, + diagnostics: false, + isolatedModules: true, + }, + ], + }, + testMatch: ["**/__tests__/**/*.test.ts"], + moduleNameMapper: { + "^vscode$": "/src/__mocks__/vscode.js", + "@modelcontextprotocol/sdk$": "/src/__mocks__/@modelcontextprotocol/sdk/index.js", + "@modelcontextprotocol/sdk/(.*)": "/src/__mocks__/@modelcontextprotocol/sdk/$1", + "^delay$": "/src/__mocks__/delay.js", + "^p-wait-for$": "/src/__mocks__/p-wait-for.js", + "^globby$": "/src/__mocks__/globby.js", + "^serialize-error$": "/src/__mocks__/serialize-error.js", + "^strip-ansi$": "/src/__mocks__/strip-ansi.js", + "^default-shell$": "/src/__mocks__/default-shell.js", + "^os-name$": "/src/__mocks__/os-name.js", + }, + transformIgnorePatterns: [ + "node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)", + ], + modulePathIgnorePatterns: [".vscode-test"], + reporters: [["jest-simple-dot-reporter", {}]], + setupFiles: [], } diff --git a/src/__mocks__/@modelcontextprotocol/sdk/client/index.js b/src/__mocks__/@modelcontextprotocol/sdk/client/index.js index 6ed5825..cfba5c4 100644 --- a/src/__mocks__/@modelcontextprotocol/sdk/client/index.js +++ b/src/__mocks__/@modelcontextprotocol/sdk/client/index.js @@ -1,17 +1,17 @@ class Client { - constructor() { - this.request = jest.fn() - } + constructor() { + this.request = jest.fn() + } - connect() { - return Promise.resolve() - } + connect() { + return Promise.resolve() + } - close() { - return Promise.resolve() - } + close() { + return Promise.resolve() + } } module.exports = { - Client -} \ No newline at end of file + Client, +} diff --git a/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js b/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js index afa42ad..39e4cb1 100644 --- a/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js +++ b/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js @@ -1,22 +1,22 @@ class StdioClientTransport { - constructor() { - this.start = jest.fn().mockResolvedValue(undefined) - this.close = jest.fn().mockResolvedValue(undefined) - this.stderr = { - on: jest.fn() - } - } + constructor() { + this.start = jest.fn().mockResolvedValue(undefined) + this.close = jest.fn().mockResolvedValue(undefined) + this.stderr = { + on: jest.fn(), + } + } } class StdioServerParameters { - constructor() { - this.command = '' - this.args = [] - this.env = {} - } + constructor() { + this.command = "" + this.args = [] + this.env = {} + } } module.exports = { - StdioClientTransport, - StdioServerParameters -} \ No newline at end of file + StdioClientTransport, + StdioServerParameters, +} diff --git a/src/__mocks__/@modelcontextprotocol/sdk/index.js b/src/__mocks__/@modelcontextprotocol/sdk/index.js index c6e43e6..4a5395a 100644 --- a/src/__mocks__/@modelcontextprotocol/sdk/index.js +++ b/src/__mocks__/@modelcontextprotocol/sdk/index.js @@ -1,24 +1,24 @@ -const { Client } = require('./client/index.js') -const { StdioClientTransport, StdioServerParameters } = require('./client/stdio.js') +const { Client } = require("./client/index.js") +const { StdioClientTransport, StdioServerParameters } = require("./client/stdio.js") const { - CallToolResultSchema, - ListToolsResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ReadResourceResultSchema, - ErrorCode, - McpError -} = require('./types.js') + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ErrorCode, + McpError, +} = require("./types.js") module.exports = { - Client, - StdioClientTransport, - StdioServerParameters, - CallToolResultSchema, - ListToolsResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ReadResourceResultSchema, - ErrorCode, - McpError -} \ No newline at end of file + Client, + StdioClientTransport, + StdioServerParameters, + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ErrorCode, + McpError, +} diff --git a/src/__mocks__/@modelcontextprotocol/sdk/types.js b/src/__mocks__/@modelcontextprotocol/sdk/types.js index a2b3ea1..2e96448 100644 --- a/src/__mocks__/@modelcontextprotocol/sdk/types.js +++ b/src/__mocks__/@modelcontextprotocol/sdk/types.js @@ -1,51 +1,51 @@ const CallToolResultSchema = { - parse: jest.fn().mockReturnValue({}) + parse: jest.fn().mockReturnValue({}), } const ListToolsResultSchema = { - parse: jest.fn().mockReturnValue({ - tools: [] - }) + parse: jest.fn().mockReturnValue({ + tools: [], + }), } const ListResourcesResultSchema = { - parse: jest.fn().mockReturnValue({ - resources: [] - }) + parse: jest.fn().mockReturnValue({ + resources: [], + }), } const ListResourceTemplatesResultSchema = { - parse: jest.fn().mockReturnValue({ - resourceTemplates: [] - }) + parse: jest.fn().mockReturnValue({ + resourceTemplates: [], + }), } const ReadResourceResultSchema = { - parse: jest.fn().mockReturnValue({ - contents: [] - }) + parse: jest.fn().mockReturnValue({ + contents: [], + }), } const ErrorCode = { - InvalidRequest: 'InvalidRequest', - MethodNotFound: 'MethodNotFound', - InvalidParams: 'InvalidParams', - InternalError: 'InternalError' + InvalidRequest: "InvalidRequest", + MethodNotFound: "MethodNotFound", + InvalidParams: "InvalidParams", + InternalError: "InternalError", } class McpError extends Error { - constructor(code, message) { - super(message) - this.code = code - } + constructor(code, message) { + super(message) + this.code = code + } } module.exports = { - CallToolResultSchema, - ListToolsResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ReadResourceResultSchema, - ErrorCode, - McpError -} \ No newline at end of file + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ErrorCode, + McpError, +} diff --git a/src/__mocks__/McpHub.ts b/src/__mocks__/McpHub.ts index d39b2d7..7aef91b 100644 --- a/src/__mocks__/McpHub.ts +++ b/src/__mocks__/McpHub.ts @@ -1,17 +1,17 @@ export class McpHub { - connections = [] - isConnecting = false + connections = [] + isConnecting = false - constructor() { - this.toggleToolAlwaysAllow = jest.fn() - this.callTool = jest.fn() - } + constructor() { + this.toggleToolAlwaysAllow = jest.fn() + this.callTool = jest.fn() + } - async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise { - return Promise.resolve() - } + async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise { + return Promise.resolve() + } - async callTool(serverName: string, toolName: string, toolArguments?: Record): Promise { - return Promise.resolve({ result: 'success' }) - } -} \ No newline at end of file + async callTool(serverName: string, toolName: string, toolArguments?: Record): Promise { + return Promise.resolve({ result: "success" }) + } +} diff --git a/src/__mocks__/default-shell.js b/src/__mocks__/default-shell.js index f03e4fb..83ad760 100644 --- a/src/__mocks__/default-shell.js +++ b/src/__mocks__/default-shell.js @@ -1,12 +1,12 @@ // Mock default shell based on platform -const os = require('os'); +const os = require("os") -let defaultShell; -if (os.platform() === 'win32') { - defaultShell = 'cmd.exe'; +let defaultShell +if (os.platform() === "win32") { + defaultShell = "cmd.exe" } else { - defaultShell = '/bin/bash'; + defaultShell = "/bin/bash" } -module.exports = defaultShell; -module.exports.default = defaultShell; \ No newline at end of file +module.exports = defaultShell +module.exports.default = defaultShell diff --git a/src/__mocks__/delay.js b/src/__mocks__/delay.js index 9ecb361..35cba90 100644 --- a/src/__mocks__/delay.js +++ b/src/__mocks__/delay.js @@ -1,6 +1,6 @@ function delay(ms) { - return new Promise(resolve => setTimeout(resolve, ms)); + return new Promise((resolve) => setTimeout(resolve, ms)) } -module.exports = delay; -module.exports.default = delay; \ No newline at end of file +module.exports = delay +module.exports.default = delay diff --git a/src/__mocks__/globby.js b/src/__mocks__/globby.js index 2584cd1..493487e 100644 --- a/src/__mocks__/globby.js +++ b/src/__mocks__/globby.js @@ -1,10 +1,10 @@ function globby(patterns, options) { - return Promise.resolve([]); + return Promise.resolve([]) } -globby.sync = function(patterns, options) { - return []; -}; +globby.sync = function (patterns, options) { + return [] +} -module.exports = globby; -module.exports.default = globby; \ No newline at end of file +module.exports = globby +module.exports.default = globby diff --git a/src/__mocks__/os-name.js b/src/__mocks__/os-name.js index e760ff3..a9b36f8 100644 --- a/src/__mocks__/os-name.js +++ b/src/__mocks__/os-name.js @@ -1,6 +1,6 @@ function osName() { - return 'macOS'; + return "macOS" } -module.exports = osName; -module.exports.default = osName; \ No newline at end of file +module.exports = osName +module.exports.default = osName diff --git a/src/__mocks__/p-wait-for.js b/src/__mocks__/p-wait-for.js index f1e6a68..4fb3617 100644 --- a/src/__mocks__/p-wait-for.js +++ b/src/__mocks__/p-wait-for.js @@ -1,20 +1,20 @@ function pWaitFor(condition, options = {}) { - return new Promise((resolve, reject) => { - const interval = setInterval(() => { - if (condition()) { - clearInterval(interval); - resolve(); - } - }, options.interval || 20); + return new Promise((resolve, reject) => { + const interval = setInterval(() => { + if (condition()) { + clearInterval(interval) + resolve() + } + }, options.interval || 20) - if (options.timeout) { - setTimeout(() => { - clearInterval(interval); - reject(new Error('Timed out')); - }, options.timeout); - } - }); + if (options.timeout) { + setTimeout(() => { + clearInterval(interval) + reject(new Error("Timed out")) + }, options.timeout) + } + }) } -module.exports = pWaitFor; -module.exports.default = pWaitFor; \ No newline at end of file +module.exports = pWaitFor +module.exports.default = pWaitFor diff --git a/src/__mocks__/serialize-error.js b/src/__mocks__/serialize-error.js index bf01dc1..66c8fdf 100644 --- a/src/__mocks__/serialize-error.js +++ b/src/__mocks__/serialize-error.js @@ -1,25 +1,25 @@ function serializeError(error) { - if (error instanceof Error) { - return { - name: error.name, - message: error.message, - stack: error.stack - }; - } - return error; + if (error instanceof Error) { + return { + name: error.name, + message: error.message, + stack: error.stack, + } + } + return error } function deserializeError(errorData) { - if (errorData && typeof errorData === 'object') { - const error = new Error(errorData.message); - error.name = errorData.name; - error.stack = errorData.stack; - return error; - } - return errorData; + if (errorData && typeof errorData === "object") { + const error = new Error(errorData.message) + error.name = errorData.name + error.stack = errorData.stack + return error + } + return errorData } module.exports = { - serializeError, - deserializeError -}; \ No newline at end of file + serializeError, + deserializeError, +} diff --git a/src/__mocks__/strip-ansi.js b/src/__mocks__/strip-ansi.js index bf7aff9..dde0687 100644 --- a/src/__mocks__/strip-ansi.js +++ b/src/__mocks__/strip-ansi.js @@ -1,7 +1,7 @@ function stripAnsi(string) { - // Simple mock that just returns the input string - return string; + // Simple mock that just returns the input string + return string } -module.exports = stripAnsi; -module.exports.default = stripAnsi; \ No newline at end of file +module.exports = stripAnsi +module.exports.default = stripAnsi diff --git a/src/__mocks__/vscode.js b/src/__mocks__/vscode.js index 23f3ae5..2e8ed72 100644 --- a/src/__mocks__/vscode.js +++ b/src/__mocks__/vscode.js @@ -1,57 +1,57 @@ const vscode = { - window: { - showInformationMessage: jest.fn(), - showErrorMessage: jest.fn(), - createTextEditorDecorationType: jest.fn().mockReturnValue({ - dispose: jest.fn() - }) - }, - workspace: { - onDidSaveTextDocument: jest.fn() - }, - Disposable: class { - dispose() {} - }, - Uri: { - file: (path) => ({ - fsPath: path, - scheme: 'file', - authority: '', - path: path, - query: '', - fragment: '', - with: jest.fn(), - toJSON: jest.fn() - }) - }, - EventEmitter: class { - constructor() { - this.event = jest.fn(); - this.fire = jest.fn(); - } - }, - ConfigurationTarget: { - Global: 1, - Workspace: 2, - WorkspaceFolder: 3 - }, - Position: class { - constructor(line, character) { - this.line = line; - this.character = character; - } - }, - Range: class { - constructor(startLine, startCharacter, endLine, endCharacter) { - this.start = new vscode.Position(startLine, startCharacter); - this.end = new vscode.Position(endLine, endCharacter); - } - }, - ThemeColor: class { - constructor(id) { - this.id = id; - } - } -}; + window: { + showInformationMessage: jest.fn(), + showErrorMessage: jest.fn(), + createTextEditorDecorationType: jest.fn().mockReturnValue({ + dispose: jest.fn(), + }), + }, + workspace: { + onDidSaveTextDocument: jest.fn(), + }, + Disposable: class { + dispose() {} + }, + Uri: { + file: (path) => ({ + fsPath: path, + scheme: "file", + authority: "", + path: path, + query: "", + fragment: "", + with: jest.fn(), + toJSON: jest.fn(), + }), + }, + EventEmitter: class { + constructor() { + this.event = jest.fn() + this.fire = jest.fn() + } + }, + ConfigurationTarget: { + Global: 1, + Workspace: 2, + WorkspaceFolder: 3, + }, + Position: class { + constructor(line, character) { + this.line = line + this.character = character + } + }, + Range: class { + constructor(startLine, startCharacter, endLine, endCharacter) { + this.start = new vscode.Position(startLine, startCharacter) + this.end = new vscode.Position(endLine, endCharacter) + } + }, + ThemeColor: class { + constructor(id) { + this.id = id + } + }, +} -module.exports = vscode; \ No newline at end of file +module.exports = vscode diff --git a/src/api/providers/__tests__/anthropic.test.ts b/src/api/providers/__tests__/anthropic.test.ts index d0357d7..df0050a 100644 --- a/src/api/providers/__tests__/anthropic.test.ts +++ b/src/api/providers/__tests__/anthropic.test.ts @@ -1,239 +1,238 @@ -import { AnthropicHandler } from '../anthropic'; -import { ApiHandlerOptions } from '../../../shared/api'; -import { ApiStream } from '../../transform/stream'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { AnthropicHandler } from "../anthropic" +import { ApiHandlerOptions } from "../../../shared/api" +import { ApiStream } from "../../transform/stream" +import { Anthropic } from "@anthropic-ai/sdk" // Mock Anthropic client -const mockBetaCreate = jest.fn(); -const mockCreate = jest.fn(); -jest.mock('@anthropic-ai/sdk', () => { - return { - Anthropic: jest.fn().mockImplementation(() => ({ - beta: { - promptCaching: { - messages: { - create: mockBetaCreate.mockImplementation(async () => ({ - async *[Symbol.asyncIterator]() { - yield { - type: 'message_start', - message: { - usage: { - input_tokens: 100, - output_tokens: 50, - cache_creation_input_tokens: 20, - cache_read_input_tokens: 10 - } - } - }; - yield { - type: 'content_block_start', - index: 0, - content_block: { - type: 'text', - text: 'Hello' - } - }; - yield { - type: 'content_block_delta', - delta: { - type: 'text_delta', - text: ' world' - } - }; - } - })) - } - } - }, - messages: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - content: [ - { type: 'text', text: 'Test response' } - ], - role: 'assistant', - model: options.model, - usage: { - input_tokens: 10, - output_tokens: 5 - } - } - } - return { - async *[Symbol.asyncIterator]() { - yield { - type: 'message_start', - message: { - usage: { - input_tokens: 10, - output_tokens: 5 - } - } - } - yield { - type: 'content_block_start', - content_block: { - type: 'text', - text: 'Test response' - } - } - } - } - }) - } - })) - }; -}); +const mockBetaCreate = jest.fn() +const mockCreate = jest.fn() +jest.mock("@anthropic-ai/sdk", () => { + return { + Anthropic: jest.fn().mockImplementation(() => ({ + beta: { + promptCaching: { + messages: { + create: mockBetaCreate.mockImplementation(async () => ({ + async *[Symbol.asyncIterator]() { + yield { + type: "message_start", + message: { + usage: { + input_tokens: 100, + output_tokens: 50, + cache_creation_input_tokens: 20, + cache_read_input_tokens: 10, + }, + }, + } + yield { + type: "content_block_start", + index: 0, + content_block: { + type: "text", + text: "Hello", + }, + } + yield { + type: "content_block_delta", + delta: { + type: "text_delta", + text: " world", + }, + } + }, + })), + }, + }, + }, + messages: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + content: [{ type: "text", text: "Test response" }], + role: "assistant", + model: options.model, + usage: { + input_tokens: 10, + output_tokens: 5, + }, + } + } + return { + async *[Symbol.asyncIterator]() { + yield { + type: "message_start", + message: { + usage: { + input_tokens: 10, + output_tokens: 5, + }, + }, + } + yield { + type: "content_block_start", + content_block: { + type: "text", + text: "Test response", + }, + } + }, + } + }), + }, + })), + } +}) -describe('AnthropicHandler', () => { - let handler: AnthropicHandler; - let mockOptions: ApiHandlerOptions; +describe("AnthropicHandler", () => { + let handler: AnthropicHandler + let mockOptions: ApiHandlerOptions - beforeEach(() => { - mockOptions = { - apiKey: 'test-api-key', - apiModelId: 'claude-3-5-sonnet-20241022' - }; - handler = new AnthropicHandler(mockOptions); - mockBetaCreate.mockClear(); - mockCreate.mockClear(); - }); + beforeEach(() => { + mockOptions = { + apiKey: "test-api-key", + apiModelId: "claude-3-5-sonnet-20241022", + } + handler = new AnthropicHandler(mockOptions) + mockBetaCreate.mockClear() + mockCreate.mockClear() + }) - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(AnthropicHandler); - expect(handler.getModel().id).toBe(mockOptions.apiModelId); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(AnthropicHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) - it('should initialize with undefined API key', () => { - // The SDK will handle API key validation, so we just verify it initializes - const handlerWithoutKey = new AnthropicHandler({ - ...mockOptions, - apiKey: undefined - }); - expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler); - }); + it("should initialize with undefined API key", () => { + // The SDK will handle API key validation, so we just verify it initializes + const handlerWithoutKey = new AnthropicHandler({ + ...mockOptions, + apiKey: undefined, + }) + expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler) + }) - it('should use custom base URL if provided', () => { - const customBaseUrl = 'https://custom.anthropic.com'; - const handlerWithCustomUrl = new AnthropicHandler({ - ...mockOptions, - anthropicBaseUrl: customBaseUrl - }); - expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler); - }); - }); + it("should use custom base URL if provided", () => { + const customBaseUrl = "https://custom.anthropic.com" + const handlerWithCustomUrl = new AnthropicHandler({ + ...mockOptions, + anthropicBaseUrl: customBaseUrl, + }) + expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler) + }) + }) - describe('createMessage', () => { - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [{ - type: 'text' as const, - text: 'Hello!' - }] - } - ]; + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] - it('should handle prompt caching for supported models', async () => { - const stream = handler.createMessage(systemPrompt, [ - { - role: 'user', - content: [{ type: 'text' as const, text: 'First message' }] - }, - { - role: 'assistant', - content: [{ type: 'text' as const, text: 'Response' }] - }, - { - role: 'user', - content: [{ type: 'text' as const, text: 'Second message' }] - } - ]); + it("should handle prompt caching for supported models", async () => { + const stream = handler.createMessage(systemPrompt, [ + { + role: "user", + content: [{ type: "text" as const, text: "First message" }], + }, + { + role: "assistant", + content: [{ type: "text" as const, text: "Response" }], + }, + { + role: "user", + content: [{ type: "text" as const, text: "Second message" }], + }, + ]) - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - // Verify usage information - const usageChunk = chunks.find(chunk => chunk.type === 'usage'); - expect(usageChunk).toBeDefined(); - expect(usageChunk?.inputTokens).toBe(100); - expect(usageChunk?.outputTokens).toBe(50); - expect(usageChunk?.cacheWriteTokens).toBe(20); - expect(usageChunk?.cacheReadTokens).toBe(10); + // Verify usage information + const usageChunk = chunks.find((chunk) => chunk.type === "usage") + expect(usageChunk).toBeDefined() + expect(usageChunk?.inputTokens).toBe(100) + expect(usageChunk?.outputTokens).toBe(50) + expect(usageChunk?.cacheWriteTokens).toBe(20) + expect(usageChunk?.cacheReadTokens).toBe(10) - // Verify text content - const textChunks = chunks.filter(chunk => chunk.type === 'text'); - expect(textChunks).toHaveLength(2); - expect(textChunks[0].text).toBe('Hello'); - expect(textChunks[1].text).toBe(' world'); + // Verify text content + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(2) + expect(textChunks[0].text).toBe("Hello") + expect(textChunks[1].text).toBe(" world") - // Verify beta API was used - expect(mockBetaCreate).toHaveBeenCalled(); - expect(mockCreate).not.toHaveBeenCalled(); - }); - }); + // Verify beta API was used + expect(mockBetaCreate).toHaveBeenCalled() + expect(mockCreate).not.toHaveBeenCalled() + }) + }) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.apiModelId, - messages: [{ role: 'user', content: 'Test prompt' }], - max_tokens: 8192, - temperature: 0, - stream: false - }); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.apiModelId, + messages: [{ role: "user", content: "Test prompt" }], + max_tokens: 8192, + temperature: 0, + stream: false, + }) + }) - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Anthropic completion error: API Error'); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error") + }) - it('should handle non-text content', async () => { - mockCreate.mockImplementationOnce(async () => ({ - content: [{ type: 'image' }] - })); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); + it("should handle non-text content", async () => { + mockCreate.mockImplementationOnce(async () => ({ + content: [{ type: "image" }], + })) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) - it('should handle empty response', async () => { - mockCreate.mockImplementationOnce(async () => ({ - content: [{ type: 'text', text: '' }] - })); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + it("should handle empty response", async () => { + mockCreate.mockImplementationOnce(async () => ({ + content: [{ type: "text", text: "" }], + })) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) - describe('getModel', () => { - it('should return default model if no model ID is provided', () => { - const handlerWithoutModel = new AnthropicHandler({ - ...mockOptions, - apiModelId: undefined - }); - const model = handlerWithoutModel.getModel(); - expect(model.id).toBeDefined(); - expect(model.info).toBeDefined(); - }); + describe("getModel", () => { + it("should return default model if no model ID is provided", () => { + const handlerWithoutModel = new AnthropicHandler({ + ...mockOptions, + apiModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBeDefined() + expect(model.info).toBeDefined() + }) - it('should return specified model if valid model ID is provided', () => { - const model = handler.getModel(); - expect(model.id).toBe(mockOptions.apiModelId); - expect(model.info).toBeDefined(); - expect(model.info.maxTokens).toBe(8192); - expect(model.info.contextWindow).toBe(200_000); - expect(model.info.supportsImages).toBe(true); - expect(model.info.supportsPromptCache).toBe(true); - }); - }); -}); \ No newline at end of file + it("should return specified model if valid model ID is provided", () => { + const model = handler.getModel() + expect(model.id).toBe(mockOptions.apiModelId) + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(200_000) + expect(model.info.supportsImages).toBe(true) + expect(model.info.supportsPromptCache).toBe(true) + }) + }) +}) diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index e8e3f3a..e8c1a44 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -1,246 +1,259 @@ -import { AwsBedrockHandler } from '../bedrock'; -import { MessageContent } from '../../../shared/api'; -import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { AwsBedrockHandler } from "../bedrock" +import { MessageContent } from "../../../shared/api" +import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" +import { Anthropic } from "@anthropic-ai/sdk" -describe('AwsBedrockHandler', () => { - let handler: AwsBedrockHandler; +describe("AwsBedrockHandler", () => { + let handler: AwsBedrockHandler - beforeEach(() => { - handler = new AwsBedrockHandler({ - apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', - awsAccessKey: 'test-access-key', - awsSecretKey: 'test-secret-key', - awsRegion: 'us-east-1' - }); - }); + beforeEach(() => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + }) - describe('constructor', () => { - it('should initialize with provided config', () => { - expect(handler['options'].awsAccessKey).toBe('test-access-key'); - expect(handler['options'].awsSecretKey).toBe('test-secret-key'); - expect(handler['options'].awsRegion).toBe('us-east-1'); - expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0'); - }); + describe("constructor", () => { + it("should initialize with provided config", () => { + expect(handler["options"].awsAccessKey).toBe("test-access-key") + expect(handler["options"].awsSecretKey).toBe("test-secret-key") + expect(handler["options"].awsRegion).toBe("us-east-1") + expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) - it('should initialize with missing AWS credentials', () => { - const handlerWithoutCreds = new AwsBedrockHandler({ - apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', - awsRegion: 'us-east-1' - }); - expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler); - }); - }); + it("should initialize with missing AWS credentials", () => { + const handlerWithoutCreds = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsRegion: "us-east-1", + }) + expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler) + }) + }) - describe('createMessage', () => { - const mockMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello' - }, - { - role: 'assistant', - content: 'Hi there!' - } - ]; + describe("createMessage", () => { + const mockMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello", + }, + { + role: "assistant", + content: "Hi there!", + }, + ] - const systemPrompt = 'You are a helpful assistant'; + const systemPrompt = "You are a helpful assistant" - it('should handle text messages correctly', async () => { - const mockResponse = { - messages: [{ - role: 'assistant', - content: [{ type: 'text', text: 'Hello! How can I help you?' }] - }], - usage: { - input_tokens: 10, - output_tokens: 5 - } - }; + it("should handle text messages correctly", async () => { + const mockResponse = { + messages: [ + { + role: "assistant", + content: [{ type: "text", text: "Hello! How can I help you?" }], + }, + ], + usage: { + input_tokens: 10, + output_tokens: 5, + }, + } - // Mock AWS SDK invoke - const mockStream = { - [Symbol.asyncIterator]: async function* () { - yield { - metadata: { - usage: { - inputTokens: 10, - outputTokens: 5 - } - } - }; - } - }; + // Mock AWS SDK invoke + const mockStream = { + [Symbol.asyncIterator]: async function* () { + yield { + metadata: { + usage: { + inputTokens: 10, + outputTokens: 5, + }, + }, + } + }, + } - const mockInvoke = jest.fn().mockResolvedValue({ - stream: mockStream - }); + const mockInvoke = jest.fn().mockResolvedValue({ + stream: mockStream, + }) - handler['client'] = { - send: mockInvoke - } as unknown as BedrockRuntimeClient; + handler["client"] = { + send: mockInvoke, + } as unknown as BedrockRuntimeClient - const stream = handler.createMessage(systemPrompt, mockMessages); - const chunks = []; - - for await (const chunk of stream) { - chunks.push(chunk); - } + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks = [] - expect(chunks.length).toBeGreaterThan(0); - expect(chunks[0]).toEqual({ - type: 'usage', - inputTokens: 10, - outputTokens: 5 - }); + for await (const chunk of stream) { + chunks.push(chunk) + } - expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({ - input: expect.objectContaining({ - modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0' - }) - })); - }); + expect(chunks.length).toBeGreaterThan(0) + expect(chunks[0]).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 5, + }) - it('should handle API errors', async () => { - // Mock AWS SDK invoke with error - const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error')); + expect(mockInvoke).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + }), + }), + ) + }) - handler['client'] = { - send: mockInvoke - } as unknown as BedrockRuntimeClient; + it("should handle API errors", async () => { + // Mock AWS SDK invoke with error + const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error")) - const stream = handler.createMessage(systemPrompt, mockMessages); + handler["client"] = { + send: mockInvoke, + } as unknown as BedrockRuntimeClient - await expect(async () => { - for await (const chunk of stream) { - // Should throw before yielding any chunks - } - }).rejects.toThrow('AWS Bedrock error'); - }); - }); + const stream = handler.createMessage(systemPrompt, mockMessages) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const mockResponse = { - output: new TextEncoder().encode(JSON.stringify({ - content: 'Test response' - })) - }; + await expect(async () => { + for await (const chunk of stream) { + // Should throw before yielding any chunks + } + }).rejects.toThrow("AWS Bedrock error") + }) + }) - const mockSend = jest.fn().mockResolvedValue(mockResponse); - handler['client'] = { - send: mockSend - } as unknown as BedrockRuntimeClient; + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const mockResponse = { + output: new TextEncoder().encode( + JSON.stringify({ + content: "Test response", + }), + ), + } - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({ - input: expect.objectContaining({ - modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', - messages: expect.arrayContaining([ - expect.objectContaining({ - role: 'user', - content: [{ text: 'Test prompt' }] - }) - ]), - inferenceConfig: expect.objectContaining({ - maxTokens: 5000, - temperature: 0.3, - topP: 0.1 - }) - }) - })); - }); + const mockSend = jest.fn().mockResolvedValue(mockResponse) + handler["client"] = { + send: mockSend, + } as unknown as BedrockRuntimeClient - it('should handle API errors', async () => { - const mockError = new Error('AWS Bedrock error'); - const mockSend = jest.fn().mockRejectedValue(mockError); - handler['client'] = { - send: mockSend - } as unknown as BedrockRuntimeClient; + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockSend).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + messages: expect.arrayContaining([ + expect.objectContaining({ + role: "user", + content: [{ text: "Test prompt" }], + }), + ]), + inferenceConfig: expect.objectContaining({ + maxTokens: 5000, + temperature: 0.3, + topP: 0.1, + }), + }), + }), + ) + }) - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Bedrock completion error: AWS Bedrock error'); - }); + it("should handle API errors", async () => { + const mockError = new Error("AWS Bedrock error") + const mockSend = jest.fn().mockRejectedValue(mockError) + handler["client"] = { + send: mockSend, + } as unknown as BedrockRuntimeClient - it('should handle invalid response format', async () => { - const mockResponse = { - output: new TextEncoder().encode('invalid json') - }; + await expect(handler.completePrompt("Test prompt")).rejects.toThrow( + "Bedrock completion error: AWS Bedrock error", + ) + }) - const mockSend = jest.fn().mockResolvedValue(mockResponse); - handler['client'] = { - send: mockSend - } as unknown as BedrockRuntimeClient; + it("should handle invalid response format", async () => { + const mockResponse = { + output: new TextEncoder().encode("invalid json"), + } - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); + const mockSend = jest.fn().mockResolvedValue(mockResponse) + handler["client"] = { + send: mockSend, + } as unknown as BedrockRuntimeClient - it('should handle empty response', async () => { - const mockResponse = { - output: new TextEncoder().encode(JSON.stringify({})) - }; + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) - const mockSend = jest.fn().mockResolvedValue(mockResponse); - handler['client'] = { - send: mockSend - } as unknown as BedrockRuntimeClient; + it("should handle empty response", async () => { + const mockResponse = { + output: new TextEncoder().encode(JSON.stringify({})), + } - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); + const mockSend = jest.fn().mockResolvedValue(mockResponse) + handler["client"] = { + send: mockSend, + } as unknown as BedrockRuntimeClient - it('should handle cross-region inference', async () => { - handler = new AwsBedrockHandler({ - apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', - awsAccessKey: 'test-access-key', - awsSecretKey: 'test-secret-key', - awsRegion: 'us-east-1', - awsUseCrossRegionInference: true - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) - const mockResponse = { - output: new TextEncoder().encode(JSON.stringify({ - content: 'Test response' - })) - }; + it("should handle cross-region inference", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsUseCrossRegionInference: true, + }) - const mockSend = jest.fn().mockResolvedValue(mockResponse); - handler['client'] = { - send: mockSend - } as unknown as BedrockRuntimeClient; + const mockResponse = { + output: new TextEncoder().encode( + JSON.stringify({ + content: "Test response", + }), + ), + } - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({ - input: expect.objectContaining({ - modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' - }) - })); - }); - }); + const mockSend = jest.fn().mockResolvedValue(mockResponse) + handler["client"] = { + send: mockSend, + } as unknown as BedrockRuntimeClient - describe('getModel', () => { - it('should return correct model info in test environment', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0'); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value - expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockSend).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + }), + }), + ) + }) + }) - it('should return test model info for invalid model in test environment', () => { - const invalidHandler = new AwsBedrockHandler({ - apiModelId: 'invalid-model', - awsAccessKey: 'test-access-key', - awsSecretKey: 'test-secret-key', - awsRegion: 'us-east-1' - }); - const modelInfo = invalidHandler.getModel(); - expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed - expect(modelInfo.info.maxTokens).toBe(5000); - expect(modelInfo.info.contextWindow).toBe(128_000); - }); - }); -}); + describe("getModel", () => { + it("should return correct model info in test environment", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value + expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value + }) + + it("should return test model info for invalid model in test environment", () => { + const invalidHandler = new AwsBedrockHandler({ + apiModelId: "invalid-model", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + const modelInfo = invalidHandler.getModel() + expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed + expect(modelInfo.info.maxTokens).toBe(5000) + expect(modelInfo.info.contextWindow).toBe(128_000) + }) + }) +}) diff --git a/src/api/providers/__tests__/deepseek.test.ts b/src/api/providers/__tests__/deepseek.test.ts index edf6598..00526dc 100644 --- a/src/api/providers/__tests__/deepseek.test.ts +++ b/src/api/providers/__tests__/deepseek.test.ts @@ -1,203 +1,217 @@ -import { DeepSeekHandler } from '../deepseek'; -import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api'; -import OpenAI from 'openai'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { DeepSeekHandler } from "../deepseek" +import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" // Mock OpenAI client -const mockCreate = jest.fn(); -jest.mock('openai', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - choices: [{ - message: { role: 'assistant', content: 'Test response', refusal: null }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - - // Return async iterator for streaming - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ - delta: { content: 'Test response' }, - index: 0 - }], - usage: null - }; - yield { - choices: [{ - delta: {}, - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - }; - }) - } - } - })) - }; -}); +const mockCreate = jest.fn() +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response", refusal: null }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + } -describe('DeepSeekHandler', () => { - let handler: DeepSeekHandler; - let mockOptions: ApiHandlerOptions; + // Return async iterator for streaming + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }, + } + }), + }, + }, + })), + } +}) - beforeEach(() => { - mockOptions = { - deepSeekApiKey: 'test-api-key', - deepSeekModelId: 'deepseek-chat', - deepSeekBaseUrl: 'https://api.deepseek.com/v1' - }; - handler = new DeepSeekHandler(mockOptions); - mockCreate.mockClear(); - }); +describe("DeepSeekHandler", () => { + let handler: DeepSeekHandler + let mockOptions: ApiHandlerOptions - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(DeepSeekHandler); - expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId); - }); + beforeEach(() => { + mockOptions = { + deepSeekApiKey: "test-api-key", + deepSeekModelId: "deepseek-chat", + deepSeekBaseUrl: "https://api.deepseek.com/v1", + } + handler = new DeepSeekHandler(mockOptions) + mockCreate.mockClear() + }) - it('should throw error if API key is missing', () => { - expect(() => { - new DeepSeekHandler({ - ...mockOptions, - deepSeekApiKey: undefined - }); - }).toThrow('DeepSeek API key is required'); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(DeepSeekHandler) + expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId) + }) - it('should use default model ID if not provided', () => { - const handlerWithoutModel = new DeepSeekHandler({ - ...mockOptions, - deepSeekModelId: undefined - }); - expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId); - }); + it("should throw error if API key is missing", () => { + expect(() => { + new DeepSeekHandler({ + ...mockOptions, + deepSeekApiKey: undefined, + }) + }).toThrow("DeepSeek API key is required") + }) - it('should use default base URL if not provided', () => { - const handlerWithoutBaseUrl = new DeepSeekHandler({ - ...mockOptions, - deepSeekBaseUrl: undefined - }); - expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler); - // The base URL is passed to OpenAI client internally - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ - baseURL: 'https://api.deepseek.com/v1' - })); - }); + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new DeepSeekHandler({ + ...mockOptions, + deepSeekModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId) + }) - it('should use custom base URL if provided', () => { - const customBaseUrl = 'https://custom.deepseek.com/v1'; - const handlerWithCustomUrl = new DeepSeekHandler({ - ...mockOptions, - deepSeekBaseUrl: customBaseUrl - }); - expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler); - // The custom base URL is passed to OpenAI client - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ - baseURL: customBaseUrl - })); - }); + it("should use default base URL if not provided", () => { + const handlerWithoutBaseUrl = new DeepSeekHandler({ + ...mockOptions, + deepSeekBaseUrl: undefined, + }) + expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler) + // The base URL is passed to OpenAI client internally + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://api.deepseek.com/v1", + }), + ) + }) - it('should set includeMaxTokens to true', () => { - // Create a new handler and verify OpenAI client was called with includeMaxTokens - new DeepSeekHandler(mockOptions); - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ - apiKey: mockOptions.deepSeekApiKey - })); - }); - }); + it("should use custom base URL if provided", () => { + const customBaseUrl = "https://custom.deepseek.com/v1" + const handlerWithCustomUrl = new DeepSeekHandler({ + ...mockOptions, + deepSeekBaseUrl: customBaseUrl, + }) + expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler) + // The custom base URL is passed to OpenAI client + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: customBaseUrl, + }), + ) + }) - describe('getModel', () => { - it('should return model info for valid model ID', () => { - const model = handler.getModel(); - expect(model.id).toBe(mockOptions.deepSeekModelId); - expect(model.info).toBeDefined(); - expect(model.info.maxTokens).toBe(8192); - expect(model.info.contextWindow).toBe(64_000); - expect(model.info.supportsImages).toBe(false); - expect(model.info.supportsPromptCache).toBe(false); - }); + it("should set includeMaxTokens to true", () => { + // Create a new handler and verify OpenAI client was called with includeMaxTokens + new DeepSeekHandler(mockOptions) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: mockOptions.deepSeekApiKey, + }), + ) + }) + }) - it('should return provided model ID with default model info if model does not exist', () => { - const handlerWithInvalidModel = new DeepSeekHandler({ - ...mockOptions, - deepSeekModelId: 'invalid-model' - }); - const model = handlerWithInvalidModel.getModel(); - expect(model.id).toBe('invalid-model'); // Returns provided ID - expect(model.info).toBeDefined(); - expect(model.info).toBe(handler.getModel().info); // But uses default model info - }); + describe("getModel", () => { + it("should return model info for valid model ID", () => { + const model = handler.getModel() + expect(model.id).toBe(mockOptions.deepSeekModelId) + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(64_000) + expect(model.info.supportsImages).toBe(false) + expect(model.info.supportsPromptCache).toBe(false) + }) - it('should return default model if no model ID is provided', () => { - const handlerWithoutModel = new DeepSeekHandler({ - ...mockOptions, - deepSeekModelId: undefined - }); - const model = handlerWithoutModel.getModel(); - expect(model.id).toBe(deepSeekDefaultModelId); - expect(model.info).toBeDefined(); - }); - }); + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new DeepSeekHandler({ + ...mockOptions, + deepSeekModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") // Returns provided ID + expect(model.info).toBeDefined() + expect(model.info).toBe(handler.getModel().info) // But uses default model info + }) - describe('createMessage', () => { - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [{ - type: 'text' as const, - text: 'Hello!' - }] - } - ]; + it("should return default model if no model ID is provided", () => { + const handlerWithoutModel = new DeepSeekHandler({ + ...mockOptions, + deepSeekModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(deepSeekDefaultModelId) + expect(model.info).toBeDefined() + }) + }) - it('should handle streaming responses', async () => { - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] - expect(chunks.length).toBeGreaterThan(0); - const textChunks = chunks.filter(chunk => chunk.type === 'text'); - expect(textChunks).toHaveLength(1); - expect(textChunks[0].text).toBe('Test response'); - }); + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it('should include usage information', async () => { - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) - const usageChunks = chunks.filter(chunk => chunk.type === 'usage'); - expect(usageChunks.length).toBeGreaterThan(0); - expect(usageChunks[0].inputTokens).toBe(10); - expect(usageChunks[0].outputTokens).toBe(5); - }); - }); -}); \ No newline at end of file + it("should include usage information", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(5) + }) + }) +}) diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index a59028e..a8a4eec 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -1,212 +1,210 @@ -import { GeminiHandler } from '../gemini'; -import { Anthropic } from '@anthropic-ai/sdk'; -import { GoogleGenerativeAI } from '@google/generative-ai'; +import { GeminiHandler } from "../gemini" +import { Anthropic } from "@anthropic-ai/sdk" +import { GoogleGenerativeAI } from "@google/generative-ai" // Mock the Google Generative AI SDK -jest.mock('@google/generative-ai', () => ({ - GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ - getGenerativeModel: jest.fn().mockReturnValue({ - generateContentStream: jest.fn(), - generateContent: jest.fn().mockResolvedValue({ - response: { - text: () => 'Test response' - } - }) - }) - })) -})); +jest.mock("@google/generative-ai", () => ({ + GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ + getGenerativeModel: jest.fn().mockReturnValue({ + generateContentStream: jest.fn(), + generateContent: jest.fn().mockResolvedValue({ + response: { + text: () => "Test response", + }, + }), + }), + })), +})) -describe('GeminiHandler', () => { - let handler: GeminiHandler; +describe("GeminiHandler", () => { + let handler: GeminiHandler - beforeEach(() => { - handler = new GeminiHandler({ - apiKey: 'test-key', - apiModelId: 'gemini-2.0-flash-thinking-exp-1219', - geminiApiKey: 'test-key' - }); - }); + beforeEach(() => { + handler = new GeminiHandler({ + apiKey: "test-key", + apiModelId: "gemini-2.0-flash-thinking-exp-1219", + geminiApiKey: "test-key", + }) + }) - describe('constructor', () => { - it('should initialize with provided config', () => { - expect(handler['options'].geminiApiKey).toBe('test-key'); - expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219'); - }); + describe("constructor", () => { + it("should initialize with provided config", () => { + expect(handler["options"].geminiApiKey).toBe("test-key") + expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219") + }) - it('should throw if API key is missing', () => { - expect(() => { - new GeminiHandler({ - apiModelId: 'gemini-2.0-flash-thinking-exp-1219', - geminiApiKey: '' - }); - }).toThrow('API key is required for Google Gemini'); - }); - }); + it("should throw if API key is missing", () => { + expect(() => { + new GeminiHandler({ + apiModelId: "gemini-2.0-flash-thinking-exp-1219", + geminiApiKey: "", + }) + }).toThrow("API key is required for Google Gemini") + }) + }) - describe('createMessage', () => { - const mockMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello' - }, - { - role: 'assistant', - content: 'Hi there!' - } - ]; + describe("createMessage", () => { + const mockMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello", + }, + { + role: "assistant", + content: "Hi there!", + }, + ] - const systemPrompt = 'You are a helpful assistant'; + const systemPrompt = "You are a helpful assistant" - it('should handle text messages correctly', async () => { - // Mock the stream response - const mockStream = { - stream: [ - { text: () => 'Hello' }, - { text: () => ' world!' } - ], - response: { - usageMetadata: { - promptTokenCount: 10, - candidatesTokenCount: 5 - } - } - }; + it("should handle text messages correctly", async () => { + // Mock the stream response + const mockStream = { + stream: [{ text: () => "Hello" }, { text: () => " world!" }], + response: { + usageMetadata: { + promptTokenCount: 10, + candidatesTokenCount: 5, + }, + }, + } - // Setup the mock implementation - const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream); - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContentStream: mockGenerateContentStream - }); + // Setup the mock implementation + const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream) + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContentStream: mockGenerateContentStream, + }) - (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - const stream = handler.createMessage(systemPrompt, mockMessages); - const chunks = []; - - for await (const chunk of stream) { - chunks.push(chunk); - } + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks = [] - // Should have 3 chunks: 'Hello', ' world!', and usage info - expect(chunks.length).toBe(3); - expect(chunks[0]).toEqual({ - type: 'text', - text: 'Hello' - }); - expect(chunks[1]).toEqual({ - type: 'text', - text: ' world!' - }); - expect(chunks[2]).toEqual({ - type: 'usage', - inputTokens: 10, - outputTokens: 5 - }); + for await (const chunk of stream) { + chunks.push(chunk) + } - // Verify the model configuration - expect(mockGetGenerativeModel).toHaveBeenCalledWith({ - model: 'gemini-2.0-flash-thinking-exp-1219', - systemInstruction: systemPrompt - }); + // Should have 3 chunks: 'Hello', ' world!', and usage info + expect(chunks.length).toBe(3) + expect(chunks[0]).toEqual({ + type: "text", + text: "Hello", + }) + expect(chunks[1]).toEqual({ + type: "text", + text: " world!", + }) + expect(chunks[2]).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 5, + }) - // Verify generation config - expect(mockGenerateContentStream).toHaveBeenCalledWith( - expect.objectContaining({ - generationConfig: { - temperature: 0 - } - }) - ); - }); + // Verify the model configuration + expect(mockGetGenerativeModel).toHaveBeenCalledWith({ + model: "gemini-2.0-flash-thinking-exp-1219", + systemInstruction: systemPrompt, + }) - it('should handle API errors', async () => { - const mockError = new Error('Gemini API error'); - const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError); - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContentStream: mockGenerateContentStream - }); + // Verify generation config + expect(mockGenerateContentStream).toHaveBeenCalledWith( + expect.objectContaining({ + generationConfig: { + temperature: 0, + }, + }), + ) + }) - (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + it("should handle API errors", async () => { + const mockError = new Error("Gemini API error") + const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError) + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContentStream: mockGenerateContentStream, + }) - const stream = handler.createMessage(systemPrompt, mockMessages); + ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - await expect(async () => { - for await (const chunk of stream) { - // Should throw before yielding any chunks - } - }).rejects.toThrow('Gemini API error'); - }); - }); + const stream = handler.createMessage(systemPrompt, mockMessages) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const mockGenerateContent = jest.fn().mockResolvedValue({ - response: { - text: () => 'Test response' - } - }); - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent - }); - (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + await expect(async () => { + for await (const chunk of stream) { + // Should throw before yielding any chunks + } + }).rejects.toThrow("Gemini API error") + }) + }) - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockGetGenerativeModel).toHaveBeenCalledWith({ - model: 'gemini-2.0-flash-thinking-exp-1219' - }); - expect(mockGenerateContent).toHaveBeenCalledWith({ - contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }], - generationConfig: { - temperature: 0 - } - }); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const mockGenerateContent = jest.fn().mockResolvedValue({ + response: { + text: () => "Test response", + }, + }) + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContent: mockGenerateContent, + }) + ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - it('should handle API errors', async () => { - const mockError = new Error('Gemini API error'); - const mockGenerateContent = jest.fn().mockRejectedValue(mockError); - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent - }); - (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockGetGenerativeModel).toHaveBeenCalledWith({ + model: "gemini-2.0-flash-thinking-exp-1219", + }) + expect(mockGenerateContent).toHaveBeenCalledWith({ + contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], + generationConfig: { + temperature: 0, + }, + }) + }) - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Gemini completion error: Gemini API error'); - }); + it("should handle API errors", async () => { + const mockError = new Error("Gemini API error") + const mockGenerateContent = jest.fn().mockRejectedValue(mockError) + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContent: mockGenerateContent, + }) + ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - it('should handle empty response', async () => { - const mockGenerateContent = jest.fn().mockResolvedValue({ - response: { - text: () => '' - } - }); - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent - }); - (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; + await expect(handler.completePrompt("Test prompt")).rejects.toThrow( + "Gemini completion error: Gemini API error", + ) + }) - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + it("should handle empty response", async () => { + const mockGenerateContent = jest.fn().mockResolvedValue({ + response: { + text: () => "", + }, + }) + const mockGetGenerativeModel = jest.fn().mockReturnValue({ + generateContent: mockGenerateContent, + }) + ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - describe('getModel', () => { - it('should return correct model info', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(8192); - expect(modelInfo.info.contextWindow).toBe(32_767); - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) - it('should return default model if invalid model specified', () => { - const invalidHandler = new GeminiHandler({ - apiModelId: 'invalid-model', - geminiApiKey: 'test-key' - }); - const modelInfo = invalidHandler.getModel(); - expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); // Default model - }); - }); -}); \ No newline at end of file + describe("getModel", () => { + it("should return correct model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(8192) + expect(modelInfo.info.contextWindow).toBe(32_767) + }) + + it("should return default model if invalid model specified", () => { + const invalidHandler = new GeminiHandler({ + apiModelId: "invalid-model", + geminiApiKey: "test-key", + }) + const modelInfo = invalidHandler.getModel() + expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model + }) + }) +}) diff --git a/src/api/providers/__tests__/glama.test.ts b/src/api/providers/__tests__/glama.test.ts index e67b80e..c3fc90e 100644 --- a/src/api/providers/__tests__/glama.test.ts +++ b/src/api/providers/__tests__/glama.test.ts @@ -1,226 +1,238 @@ -import { GlamaHandler } from '../glama'; -import { ApiHandlerOptions } from '../../../shared/api'; -import OpenAI from 'openai'; -import { Anthropic } from '@anthropic-ai/sdk'; -import axios from 'axios'; +import { GlamaHandler } from "../glama" +import { ApiHandlerOptions } from "../../../shared/api" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" +import axios from "axios" // Mock OpenAI client -const mockCreate = jest.fn(); -const mockWithResponse = jest.fn(); +const mockCreate = jest.fn() +const mockWithResponse = jest.fn() -jest.mock('openai', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: (...args: any[]) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ - delta: { content: 'Test response' }, - index: 0 - }], - usage: null - }; - yield { - choices: [{ - delta: {}, - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - }; +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: (...args: any[]) => { + const stream = { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }, + } - const result = mockCreate(...args); - if (args[0].stream) { - mockWithResponse.mockReturnValue(Promise.resolve({ - data: stream, - response: { - headers: { - get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null - } - } - })); - result.withResponse = mockWithResponse; - } - return result; - } - } - } - })) - }; -}); + const result = mockCreate(...args) + if (args[0].stream) { + mockWithResponse.mockReturnValue( + Promise.resolve({ + data: stream, + response: { + headers: { + get: (name: string) => + name === "x-completion-request-id" ? "test-request-id" : null, + }, + }, + }), + ) + result.withResponse = mockWithResponse + } + return result + }, + }, + }, + })), + } +}) -describe('GlamaHandler', () => { - let handler: GlamaHandler; - let mockOptions: ApiHandlerOptions; +describe("GlamaHandler", () => { + let handler: GlamaHandler + let mockOptions: ApiHandlerOptions - beforeEach(() => { - mockOptions = { - apiModelId: 'anthropic/claude-3-5-sonnet', - glamaModelId: 'anthropic/claude-3-5-sonnet', - glamaApiKey: 'test-api-key' - }; - handler = new GlamaHandler(mockOptions); - mockCreate.mockClear(); - mockWithResponse.mockClear(); + beforeEach(() => { + mockOptions = { + apiModelId: "anthropic/claude-3-5-sonnet", + glamaModelId: "anthropic/claude-3-5-sonnet", + glamaApiKey: "test-api-key", + } + handler = new GlamaHandler(mockOptions) + mockCreate.mockClear() + mockWithResponse.mockClear() - // Default mock implementation for non-streaming responses - mockCreate.mockResolvedValue({ - id: 'test-completion', - choices: [{ - message: { role: 'assistant', content: 'Test response' }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }); - }); + // Default mock implementation for non-streaming responses + mockCreate.mockResolvedValue({ + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response" }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + }) + }) - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(GlamaHandler); - expect(handler.getModel().id).toBe(mockOptions.apiModelId); - }); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(GlamaHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) + }) - describe('createMessage', () => { - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello!' - } - ]; + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello!", + }, + ] - it('should handle streaming responses', async () => { - // Mock axios for token usage request - const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({ - data: { - tokenUsage: { - promptTokens: 10, - completionTokens: 5, - cacheCreationInputTokens: 0, - cacheReadInputTokens: 0 - }, - totalCostUsd: "0.00" - } - }); + it("should handle streaming responses", async () => { + // Mock axios for token usage request + const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({ + data: { + tokenUsage: { + promptTokens: 10, + completionTokens: 5, + cacheCreationInputTokens: 0, + cacheReadInputTokens: 0, + }, + totalCostUsd: "0.00", + }, + }) - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - expect(chunks.length).toBe(2); // Text chunk and usage chunk - expect(chunks[0]).toEqual({ - type: 'text', - text: 'Test response' - }); - expect(chunks[1]).toEqual({ - type: 'usage', - inputTokens: 10, - outputTokens: 5, - cacheWriteTokens: 0, - cacheReadTokens: 0, - totalCost: 0 - }); + expect(chunks.length).toBe(2) // Text chunk and usage chunk + expect(chunks[0]).toEqual({ + type: "text", + text: "Test response", + }) + expect(chunks[1]).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 5, + cacheWriteTokens: 0, + cacheReadTokens: 0, + totalCost: 0, + }) - mockAxios.mockRestore(); - }); + mockAxios.mockRestore() + }) - it('should handle API errors', async () => { - mockCreate.mockImplementationOnce(() => { - throw new Error('API Error'); - }); + it("should handle API errors", async () => { + mockCreate.mockImplementationOnce(() => { + throw new Error("API Error") + }) - const stream = handler.createMessage(systemPrompt, messages); - const chunks = []; + const stream = handler.createMessage(systemPrompt, messages) + const chunks = [] - try { - for await (const chunk of stream) { - chunks.push(chunk); - } - fail('Expected error to be thrown'); - } catch (error) { - expect(error).toBeInstanceOf(Error); - expect(error.message).toBe('API Error'); - } - }); - }); + try { + for await (const chunk of stream) { + chunks.push(chunk) + } + fail("Expected error to be thrown") + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect(error.message).toBe("API Error") + } + }) + }) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ - model: mockOptions.apiModelId, - messages: [{ role: 'user', content: 'Test prompt' }], - temperature: 0, - max_tokens: 8192 - })); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockOptions.apiModelId, + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + max_tokens: 8192, + }), + ) + }) - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Glama completion error: API Error'); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error") + }) - it('should handle empty response', async () => { - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: '' } }] - }); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); + it("should handle empty response", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "" } }], + }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) - it('should not set max_tokens for non-Anthropic models', async () => { - // Reset mock to clear any previous calls - mockCreate.mockClear(); - - const nonAnthropicOptions = { - apiModelId: 'openai/gpt-4', - glamaModelId: 'openai/gpt-4', - glamaApiKey: 'test-key', - glamaModelInfo: { - maxTokens: 4096, - contextWindow: 8192, - supportsImages: true, - supportsPromptCache: false - } - }; - const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions); + it("should not set max_tokens for non-Anthropic models", async () => { + // Reset mock to clear any previous calls + mockCreate.mockClear() - await nonAnthropicHandler.completePrompt('Test prompt'); - expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ - model: 'openai/gpt-4', - messages: [{ role: 'user', content: 'Test prompt' }], - temperature: 0 - })); - expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens'); - }); - }); + const nonAnthropicOptions = { + apiModelId: "openai/gpt-4", + glamaModelId: "openai/gpt-4", + glamaApiKey: "test-key", + glamaModelInfo: { + maxTokens: 4096, + contextWindow: 8192, + supportsImages: true, + supportsPromptCache: false, + }, + } + const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions) - describe('getModel', () => { - it('should return model info', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe(mockOptions.apiModelId); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(8192); - expect(modelInfo.info.contextWindow).toBe(200_000); - }); - }); -}); \ No newline at end of file + await nonAnthropicHandler.completePrompt("Test prompt") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "openai/gpt-4", + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + }), + ) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens") + }) + }) + + describe("getModel", () => { + it("should return model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe(mockOptions.apiModelId) + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(8192) + expect(modelInfo.info.contextWindow).toBe(200_000) + }) + }) +}) diff --git a/src/api/providers/__tests__/lmstudio.test.ts b/src/api/providers/__tests__/lmstudio.test.ts index 6b84796..114f993 100644 --- a/src/api/providers/__tests__/lmstudio.test.ts +++ b/src/api/providers/__tests__/lmstudio.test.ts @@ -1,160 +1,167 @@ -import { LmStudioHandler } from '../lmstudio'; -import { ApiHandlerOptions } from '../../../shared/api'; -import OpenAI from 'openai'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { LmStudioHandler } from "../lmstudio" +import { ApiHandlerOptions } from "../../../shared/api" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" // Mock OpenAI client -const mockCreate = jest.fn(); -jest.mock('openai', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - choices: [{ - message: { role: 'assistant', content: 'Test response' }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ - delta: { content: 'Test response' }, - index: 0 - }], - usage: null - }; - yield { - choices: [{ - delta: {}, - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - }; - }) - } - } - })) - }; -}); +const mockCreate = jest.fn() +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response" }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + } -describe('LmStudioHandler', () => { - let handler: LmStudioHandler; - let mockOptions: ApiHandlerOptions; + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }, + } + }), + }, + }, + })), + } +}) - beforeEach(() => { - mockOptions = { - apiModelId: 'local-model', - lmStudioModelId: 'local-model', - lmStudioBaseUrl: 'http://localhost:1234/v1' - }; - handler = new LmStudioHandler(mockOptions); - mockCreate.mockClear(); - }); +describe("LmStudioHandler", () => { + let handler: LmStudioHandler + let mockOptions: ApiHandlerOptions - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(LmStudioHandler); - expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId); - }); + beforeEach(() => { + mockOptions = { + apiModelId: "local-model", + lmStudioModelId: "local-model", + lmStudioBaseUrl: "http://localhost:1234/v1", + } + handler = new LmStudioHandler(mockOptions) + mockCreate.mockClear() + }) - it('should use default base URL if not provided', () => { - const handlerWithoutUrl = new LmStudioHandler({ - apiModelId: 'local-model', - lmStudioModelId: 'local-model' - }); - expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler); - }); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(LmStudioHandler) + expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId) + }) - describe('createMessage', () => { - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello!' - } - ]; + it("should use default base URL if not provided", () => { + const handlerWithoutUrl = new LmStudioHandler({ + apiModelId: "local-model", + lmStudioModelId: "local-model", + }) + expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler) + }) + }) - it('should handle streaming responses', async () => { - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello!", + }, + ] - expect(chunks.length).toBeGreaterThan(0); - const textChunks = chunks.filter(chunk => chunk.type === 'text'); - expect(textChunks).toHaveLength(1); - expect(textChunks[0].text).toBe('Test response'); - }); + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) - const stream = handler.createMessage(systemPrompt, messages); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) - await expect(async () => { - for await (const chunk of stream) { - // Should not reach here - } - }).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong'); - }); - }); + const stream = handler.createMessage(systemPrompt, messages) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.lmStudioModelId, - messages: [{ role: 'user', content: 'Test prompt' }], - temperature: 0, - stream: false - }); - }); + await expect(async () => { + for await (const chunk of stream) { + // Should not reach here + } + }).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong") + }) + }) - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong'); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.lmStudioModelId, + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + stream: false, + }) + }) - it('should handle empty response', async () => { - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: '' } }] - }); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow( + "Please check the LM Studio developer logs to debug what went wrong", + ) + }) - describe('getModel', () => { - it('should return model info', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe(mockOptions.lmStudioModelId); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(-1); - expect(modelInfo.info.contextWindow).toBe(128_000); - }); - }); -}); \ No newline at end of file + it("should handle empty response", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "" } }], + }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) + + describe("getModel", () => { + it("should return model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe(mockOptions.lmStudioModelId) + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(-1) + expect(modelInfo.info.contextWindow).toBe(128_000) + }) + }) +}) diff --git a/src/api/providers/__tests__/ollama.test.ts b/src/api/providers/__tests__/ollama.test.ts index fc4c9f5..a0fc009 100644 --- a/src/api/providers/__tests__/ollama.test.ts +++ b/src/api/providers/__tests__/ollama.test.ts @@ -1,160 +1,165 @@ -import { OllamaHandler } from '../ollama'; -import { ApiHandlerOptions } from '../../../shared/api'; -import OpenAI from 'openai'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { OllamaHandler } from "../ollama" +import { ApiHandlerOptions } from "../../../shared/api" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" // Mock OpenAI client -const mockCreate = jest.fn(); -jest.mock('openai', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - choices: [{ - message: { role: 'assistant', content: 'Test response' }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ - delta: { content: 'Test response' }, - index: 0 - }], - usage: null - }; - yield { - choices: [{ - delta: {}, - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - }; - }) - } - } - })) - }; -}); +const mockCreate = jest.fn() +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response" }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + } -describe('OllamaHandler', () => { - let handler: OllamaHandler; - let mockOptions: ApiHandlerOptions; + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }, + } + }), + }, + }, + })), + } +}) - beforeEach(() => { - mockOptions = { - apiModelId: 'llama2', - ollamaModelId: 'llama2', - ollamaBaseUrl: 'http://localhost:11434/v1' - }; - handler = new OllamaHandler(mockOptions); - mockCreate.mockClear(); - }); +describe("OllamaHandler", () => { + let handler: OllamaHandler + let mockOptions: ApiHandlerOptions - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(OllamaHandler); - expect(handler.getModel().id).toBe(mockOptions.ollamaModelId); - }); + beforeEach(() => { + mockOptions = { + apiModelId: "llama2", + ollamaModelId: "llama2", + ollamaBaseUrl: "http://localhost:11434/v1", + } + handler = new OllamaHandler(mockOptions) + mockCreate.mockClear() + }) - it('should use default base URL if not provided', () => { - const handlerWithoutUrl = new OllamaHandler({ - apiModelId: 'llama2', - ollamaModelId: 'llama2' - }); - expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler); - }); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(OllamaHandler) + expect(handler.getModel().id).toBe(mockOptions.ollamaModelId) + }) - describe('createMessage', () => { - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello!' - } - ]; + it("should use default base URL if not provided", () => { + const handlerWithoutUrl = new OllamaHandler({ + apiModelId: "llama2", + ollamaModelId: "llama2", + }) + expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler) + }) + }) - it('should handle streaming responses', async () => { - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello!", + }, + ] - expect(chunks.length).toBeGreaterThan(0); - const textChunks = chunks.filter(chunk => chunk.type === 'text'); - expect(textChunks).toHaveLength(1); - expect(textChunks[0].text).toBe('Test response'); - }); + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) - const stream = handler.createMessage(systemPrompt, messages); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) - await expect(async () => { - for await (const chunk of stream) { - // Should not reach here - } - }).rejects.toThrow('API Error'); - }); - }); + const stream = handler.createMessage(systemPrompt, messages) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.ollamaModelId, - messages: [{ role: 'user', content: 'Test prompt' }], - temperature: 0, - stream: false - }); - }); + await expect(async () => { + for await (const chunk of stream) { + // Should not reach here + } + }).rejects.toThrow("API Error") + }) + }) - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Ollama completion error: API Error'); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.ollamaModelId, + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + stream: false, + }) + }) - it('should handle empty response', async () => { - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: '' } }] - }); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error") + }) - describe('getModel', () => { - it('should return model info', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe(mockOptions.ollamaModelId); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(-1); - expect(modelInfo.info.contextWindow).toBe(128_000); - }); - }); -}); \ No newline at end of file + it("should handle empty response", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "" } }], + }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) + + describe("getModel", () => { + it("should return model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe(mockOptions.ollamaModelId) + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(-1) + expect(modelInfo.info.contextWindow).toBe(128_000) + }) + }) +}) diff --git a/src/api/providers/__tests__/openai-native.test.ts b/src/api/providers/__tests__/openai-native.test.ts index 7b263b0..f1da211 100644 --- a/src/api/providers/__tests__/openai-native.test.ts +++ b/src/api/providers/__tests__/openai-native.test.ts @@ -1,319 +1,326 @@ -import { OpenAiNativeHandler } from '../openai-native'; -import { ApiHandlerOptions } from '../../../shared/api'; -import OpenAI from 'openai'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { OpenAiNativeHandler } from "../openai-native" +import { ApiHandlerOptions } from "../../../shared/api" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" // Mock OpenAI client -const mockCreate = jest.fn(); -jest.mock('openai', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - choices: [{ - message: { role: 'assistant', content: 'Test response' }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ - delta: { content: 'Test response' }, - index: 0 - }], - usage: null - }; - yield { - choices: [{ - delta: {}, - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - }; - }) - } - } - })) - }; -}); +const mockCreate = jest.fn() +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response" }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + } -describe('OpenAiNativeHandler', () => { - let handler: OpenAiNativeHandler; - let mockOptions: ApiHandlerOptions; - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello!' - } - ]; + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }, + } + }), + }, + }, + })), + } +}) - beforeEach(() => { - mockOptions = { - apiModelId: 'gpt-4o', - openAiNativeApiKey: 'test-api-key' - }; - handler = new OpenAiNativeHandler(mockOptions); - mockCreate.mockClear(); - }); +describe("OpenAiNativeHandler", () => { + let handler: OpenAiNativeHandler + let mockOptions: ApiHandlerOptions + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello!", + }, + ] - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(OpenAiNativeHandler); - expect(handler.getModel().id).toBe(mockOptions.apiModelId); - }); + beforeEach(() => { + mockOptions = { + apiModelId: "gpt-4o", + openAiNativeApiKey: "test-api-key", + } + handler = new OpenAiNativeHandler(mockOptions) + mockCreate.mockClear() + }) - it('should initialize with empty API key', () => { - const handlerWithoutKey = new OpenAiNativeHandler({ - apiModelId: 'gpt-4o', - openAiNativeApiKey: '' - }); - expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler); - }); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(OpenAiNativeHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) - describe('createMessage', () => { - it('should handle streaming responses', async () => { - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + it("should initialize with empty API key", () => { + const handlerWithoutKey = new OpenAiNativeHandler({ + apiModelId: "gpt-4o", + openAiNativeApiKey: "", + }) + expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler) + }) + }) - expect(chunks.length).toBeGreaterThan(0); - const textChunks = chunks.filter(chunk => chunk.type === 'text'); - expect(textChunks).toHaveLength(1); - expect(textChunks[0].text).toBe('Test response'); - }); + describe("createMessage", () => { + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - const stream = handler.createMessage(systemPrompt, messages); - await expect(async () => { - for await (const chunk of stream) { - // Should not reach here - } - }).rejects.toThrow('API Error'); - }); + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) - it('should handle missing content in response for o1 model', async () => { - // Use o1 model which supports developer role - handler = new OpenAiNativeHandler({ - ...mockOptions, - apiModelId: 'o1' - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + const stream = handler.createMessage(systemPrompt, messages) + await expect(async () => { + for await (const chunk of stream) { + // Should not reach here + } + }).rejects.toThrow("API Error") + }) - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: null } }], - usage: { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0 - } - }); + it("should handle missing content in response for o1 model", async () => { + // Use o1 model which supports developer role + handler = new OpenAiNativeHandler({ + ...mockOptions, + apiModelId: "o1", + }) - const generator = handler.createMessage(systemPrompt, messages); - const results = []; - for await (const result of generator) { - results.push(result); - } + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: null } }], + usage: { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + }) - expect(results).toEqual([ - { type: 'text', text: '' }, - { type: 'usage', inputTokens: 0, outputTokens: 0 } - ]); + const generator = handler.createMessage(systemPrompt, messages) + const results = [] + for await (const result of generator) { + results.push(result) + } - // Verify developer role is used for system prompt with o1 model - expect(mockCreate).toHaveBeenCalledWith({ - model: 'o1', - messages: [ - { role: 'developer', content: systemPrompt }, - { role: 'user', content: 'Hello!' } - ] - }); - }); - }); + expect(results).toEqual([ + { type: "text", text: "" }, + { type: "usage", inputTokens: 0, outputTokens: 0 }, + ]) - describe('streaming models', () => { - beforeEach(() => { - handler = new OpenAiNativeHandler({ - ...mockOptions, - apiModelId: 'gpt-4o', - }); - }); + // Verify developer role is used for system prompt with o1 model + expect(mockCreate).toHaveBeenCalledWith({ + model: "o1", + messages: [ + { role: "developer", content: systemPrompt }, + { role: "user", content: "Hello!" }, + ], + }) + }) + }) - it('should handle streaming response', async () => { - const mockStream = [ - { choices: [{ delta: { content: 'Hello' } }], usage: null }, - { choices: [{ delta: { content: ' there' } }], usage: null }, - { choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, - ]; + describe("streaming models", () => { + beforeEach(() => { + handler = new OpenAiNativeHandler({ + ...mockOptions, + apiModelId: "gpt-4o", + }) + }) - mockCreate.mockResolvedValueOnce( - (async function* () { - for (const chunk of mockStream) { - yield chunk; - } - })() - ); + it("should handle streaming response", async () => { + const mockStream = [ + { choices: [{ delta: { content: "Hello" } }], usage: null }, + { choices: [{ delta: { content: " there" } }], usage: null }, + { choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, + ] - const generator = handler.createMessage(systemPrompt, messages); - const results = []; - for await (const result of generator) { - results.push(result); - } + mockCreate.mockResolvedValueOnce( + (async function* () { + for (const chunk of mockStream) { + yield chunk + } + })(), + ) - expect(results).toEqual([ - { type: 'text', text: 'Hello' }, - { type: 'text', text: ' there' }, - { type: 'text', text: '!' }, - { type: 'usage', inputTokens: 10, outputTokens: 5 }, - ]); + const generator = handler.createMessage(systemPrompt, messages) + const results = [] + for await (const result of generator) { + results.push(result) + } - expect(mockCreate).toHaveBeenCalledWith({ - model: 'gpt-4o', - temperature: 0, - messages: [ - { role: 'system', content: systemPrompt }, - { role: 'user', content: 'Hello!' }, - ], - stream: true, - stream_options: { include_usage: true }, - }); - }); + expect(results).toEqual([ + { type: "text", text: "Hello" }, + { type: "text", text: " there" }, + { type: "text", text: "!" }, + { type: "usage", inputTokens: 10, outputTokens: 5 }, + ]) - it('should handle empty delta content', async () => { - const mockStream = [ - { choices: [{ delta: {} }], usage: null }, - { choices: [{ delta: { content: null } }], usage: null }, - { choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, - ]; + expect(mockCreate).toHaveBeenCalledWith({ + model: "gpt-4o", + temperature: 0, + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: "Hello!" }, + ], + stream: true, + stream_options: { include_usage: true }, + }) + }) - mockCreate.mockResolvedValueOnce( - (async function* () { - for (const chunk of mockStream) { - yield chunk; - } - })() - ); + it("should handle empty delta content", async () => { + const mockStream = [ + { choices: [{ delta: {} }], usage: null }, + { choices: [{ delta: { content: null } }], usage: null }, + { choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, + ] - const generator = handler.createMessage(systemPrompt, messages); - const results = []; - for await (const result of generator) { - results.push(result); - } + mockCreate.mockResolvedValueOnce( + (async function* () { + for (const chunk of mockStream) { + yield chunk + } + })(), + ) - expect(results).toEqual([ - { type: 'text', text: 'Hello' }, - { type: 'usage', inputTokens: 10, outputTokens: 5 }, - ]); - }); - }); + const generator = handler.createMessage(systemPrompt, messages) + const results = [] + for await (const result of generator) { + results.push(result) + } - describe('completePrompt', () => { - it('should complete prompt successfully with gpt-4o model', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: 'gpt-4o', - messages: [{ role: 'user', content: 'Test prompt' }], - temperature: 0 - }); - }); + expect(results).toEqual([ + { type: "text", text: "Hello" }, + { type: "usage", inputTokens: 10, outputTokens: 5 }, + ]) + }) + }) - it('should complete prompt successfully with o1 model', async () => { - handler = new OpenAiNativeHandler({ - apiModelId: 'o1', - openAiNativeApiKey: 'test-api-key' - }); + describe("completePrompt", () => { + it("should complete prompt successfully with gpt-4o model", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: "gpt-4o", + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + }) + }) - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: 'o1', - messages: [{ role: 'user', content: 'Test prompt' }] - }); - }); + it("should complete prompt successfully with o1 model", async () => { + handler = new OpenAiNativeHandler({ + apiModelId: "o1", + openAiNativeApiKey: "test-api-key", + }) - it('should complete prompt successfully with o1-preview model', async () => { - handler = new OpenAiNativeHandler({ - apiModelId: 'o1-preview', - openAiNativeApiKey: 'test-api-key' - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: "o1", + messages: [{ role: "user", content: "Test prompt" }], + }) + }) - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: 'o1-preview', - messages: [{ role: 'user', content: 'Test prompt' }] - }); - }); + it("should complete prompt successfully with o1-preview model", async () => { + handler = new OpenAiNativeHandler({ + apiModelId: "o1-preview", + openAiNativeApiKey: "test-api-key", + }) - it('should complete prompt successfully with o1-mini model', async () => { - handler = new OpenAiNativeHandler({ - apiModelId: 'o1-mini', - openAiNativeApiKey: 'test-api-key' - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: "o1-preview", + messages: [{ role: "user", content: "Test prompt" }], + }) + }) - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: 'o1-mini', - messages: [{ role: 'user', content: 'Test prompt' }] - }); - }); + it("should complete prompt successfully with o1-mini model", async () => { + handler = new OpenAiNativeHandler({ + apiModelId: "o1-mini", + openAiNativeApiKey: "test-api-key", + }) - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('OpenAI Native completion error: API Error'); - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: "o1-mini", + messages: [{ role: "user", content: "Test prompt" }], + }) + }) - it('should handle empty response', async () => { - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: '' } }] - }); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow( + "OpenAI Native completion error: API Error", + ) + }) - describe('getModel', () => { - it('should return model info', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe(mockOptions.apiModelId); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(4096); - expect(modelInfo.info.contextWindow).toBe(128_000); - }); + it("should handle empty response", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "" } }], + }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) - it('should handle undefined model ID', () => { - const handlerWithoutModel = new OpenAiNativeHandler({ - openAiNativeApiKey: 'test-api-key' - }); - const modelInfo = handlerWithoutModel.getModel(); - expect(modelInfo.id).toBe('gpt-4o'); // Default model - expect(modelInfo.info).toBeDefined(); - }); - }); -}); + describe("getModel", () => { + it("should return model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe(mockOptions.apiModelId) + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(4096) + expect(modelInfo.info.contextWindow).toBe(128_000) + }) + + it("should handle undefined model ID", () => { + const handlerWithoutModel = new OpenAiNativeHandler({ + openAiNativeApiKey: "test-api-key", + }) + const modelInfo = handlerWithoutModel.getModel() + expect(modelInfo.id).toBe("gpt-4o") // Default model + expect(modelInfo.info).toBeDefined() + }) + }) +}) diff --git a/src/api/providers/__tests__/openai.test.ts b/src/api/providers/__tests__/openai.test.ts index 4a4a449..ba65971 100644 --- a/src/api/providers/__tests__/openai.test.ts +++ b/src/api/providers/__tests__/openai.test.ts @@ -1,224 +1,233 @@ -import { OpenAiHandler } from '../openai'; -import { ApiHandlerOptions } from '../../../shared/api'; -import { ApiStream } from '../../transform/stream'; -import OpenAI from 'openai'; -import { Anthropic } from '@anthropic-ai/sdk'; +import { OpenAiHandler } from "../openai" +import { ApiHandlerOptions } from "../../../shared/api" +import { ApiStream } from "../../transform/stream" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" // Mock OpenAI client -const mockCreate = jest.fn(); -jest.mock('openai', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - choices: [{ - message: { role: 'assistant', content: 'Test response', refusal: null }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ - delta: { content: 'Test response' }, - index: 0 - }], - usage: null - }; - yield { - choices: [{ - delta: {}, - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - } - }; - } - }; - }) - } - } - })) - }; -}); +const mockCreate = jest.fn() +jest.mock("openai", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test response", refusal: null }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + } -describe('OpenAiHandler', () => { - let handler: OpenAiHandler; - let mockOptions: ApiHandlerOptions; + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }, + } + }), + }, + }, + })), + } +}) - beforeEach(() => { - mockOptions = { - openAiApiKey: 'test-api-key', - openAiModelId: 'gpt-4', - openAiBaseUrl: 'https://api.openai.com/v1' - }; - handler = new OpenAiHandler(mockOptions); - mockCreate.mockClear(); - }); +describe("OpenAiHandler", () => { + let handler: OpenAiHandler + let mockOptions: ApiHandlerOptions - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeInstanceOf(OpenAiHandler); - expect(handler.getModel().id).toBe(mockOptions.openAiModelId); - }); + beforeEach(() => { + mockOptions = { + openAiApiKey: "test-api-key", + openAiModelId: "gpt-4", + openAiBaseUrl: "https://api.openai.com/v1", + } + handler = new OpenAiHandler(mockOptions) + mockCreate.mockClear() + }) - it('should use custom base URL if provided', () => { - const customBaseUrl = 'https://custom.openai.com/v1'; - const handlerWithCustomUrl = new OpenAiHandler({ - ...mockOptions, - openAiBaseUrl: customBaseUrl - }); - expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler); - }); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(OpenAiHandler) + expect(handler.getModel().id).toBe(mockOptions.openAiModelId) + }) - describe('createMessage', () => { - const systemPrompt = 'You are a helpful assistant.'; - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [{ - type: 'text' as const, - text: 'Hello!' - }] - } - ]; + it("should use custom base URL if provided", () => { + const customBaseUrl = "https://custom.openai.com/v1" + const handlerWithCustomUrl = new OpenAiHandler({ + ...mockOptions, + openAiBaseUrl: customBaseUrl, + }) + expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler) + }) + }) - it('should handle non-streaming mode', async () => { - const handler = new OpenAiHandler({ - ...mockOptions, - openAiStreamingEnabled: false - }); + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + it("should handle non-streaming mode", async () => { + const handler = new OpenAiHandler({ + ...mockOptions, + openAiStreamingEnabled: false, + }) - expect(chunks.length).toBeGreaterThan(0); - const textChunk = chunks.find(chunk => chunk.type === 'text'); - const usageChunk = chunks.find(chunk => chunk.type === 'usage'); - - expect(textChunk).toBeDefined(); - expect(textChunk?.text).toBe('Test response'); - expect(usageChunk).toBeDefined(); - expect(usageChunk?.inputTokens).toBe(10); - expect(usageChunk?.outputTokens).toBe(5); - }); + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it('should handle streaming responses', async () => { - const stream = handler.createMessage(systemPrompt, messages); - const chunks: any[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } + expect(chunks.length).toBeGreaterThan(0) + const textChunk = chunks.find((chunk) => chunk.type === "text") + const usageChunk = chunks.find((chunk) => chunk.type === "usage") - expect(chunks.length).toBeGreaterThan(0); - const textChunks = chunks.filter(chunk => chunk.type === 'text'); - expect(textChunks).toHaveLength(1); - expect(textChunks[0].text).toBe('Test response'); - }); - }); + expect(textChunk).toBeDefined() + expect(textChunk?.text).toBe("Test response") + expect(usageChunk).toBeDefined() + expect(usageChunk?.inputTokens).toBe(10) + expect(usageChunk?.outputTokens).toBe(5) + }) - describe('error handling', () => { - const testMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [{ - type: 'text' as const, - text: 'Hello' - }] - } - ]; + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) + }) - const stream = handler.createMessage('system prompt', testMessages); + describe("error handling", () => { + const testMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello", + }, + ], + }, + ] - await expect(async () => { - for await (const chunk of stream) { - // Should not reach here - } - }).rejects.toThrow('API Error'); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) - it('should handle rate limiting', async () => { - const rateLimitError = new Error('Rate limit exceeded'); - rateLimitError.name = 'Error'; - (rateLimitError as any).status = 429; - mockCreate.mockRejectedValueOnce(rateLimitError); + const stream = handler.createMessage("system prompt", testMessages) - const stream = handler.createMessage('system prompt', testMessages); + await expect(async () => { + for await (const chunk of stream) { + // Should not reach here + } + }).rejects.toThrow("API Error") + }) - await expect(async () => { - for await (const chunk of stream) { - // Should not reach here - } - }).rejects.toThrow('Rate limit exceeded'); - }); - }); + it("should handle rate limiting", async () => { + const rateLimitError = new Error("Rate limit exceeded") + rateLimitError.name = "Error" + ;(rateLimitError as any).status = 429 + mockCreate.mockRejectedValueOnce(rateLimitError) - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.openAiModelId, - messages: [{ role: 'user', content: 'Test prompt' }], - temperature: 0 - }); - }); + const stream = handler.createMessage("system prompt", testMessages) - it('should handle API errors', async () => { - mockCreate.mockRejectedValueOnce(new Error('API Error')); - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('OpenAI completion error: API Error'); - }); + await expect(async () => { + for await (const chunk of stream) { + // Should not reach here + } + }).rejects.toThrow("Rate limit exceeded") + }) + }) - it('should handle empty response', async () => { - mockCreate.mockImplementationOnce(() => ({ - choices: [{ message: { content: '' } }] - })); - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.openAiModelId, + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + }) + }) - describe('getModel', () => { - it('should return model info with sane defaults', () => { - const model = handler.getModel(); - expect(model.id).toBe(mockOptions.openAiModelId); - expect(model.info).toBeDefined(); - expect(model.info.contextWindow).toBe(128_000); - expect(model.info.supportsImages).toBe(true); - }); + it("should handle API errors", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error") + }) - it('should handle undefined model ID', () => { - const handlerWithoutModel = new OpenAiHandler({ - ...mockOptions, - openAiModelId: undefined - }); - const model = handlerWithoutModel.getModel(); - expect(model.id).toBe(''); - expect(model.info).toBeDefined(); - }); - }); -}); \ No newline at end of file + it("should handle empty response", async () => { + mockCreate.mockImplementationOnce(() => ({ + choices: [{ message: { content: "" } }], + })) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) + + describe("getModel", () => { + it("should return model info with sane defaults", () => { + const model = handler.getModel() + expect(model.id).toBe(mockOptions.openAiModelId) + expect(model.info).toBeDefined() + expect(model.info.contextWindow).toBe(128_000) + expect(model.info.supportsImages).toBe(true) + }) + + it("should handle undefined model ID", () => { + const handlerWithoutModel = new OpenAiHandler({ + ...mockOptions, + openAiModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe("") + expect(model.info).toBeDefined() + }) + }) +}) diff --git a/src/api/providers/__tests__/openrouter.test.ts b/src/api/providers/__tests__/openrouter.test.ts index fb24516..b395e27 100644 --- a/src/api/providers/__tests__/openrouter.test.ts +++ b/src/api/providers/__tests__/openrouter.test.ts @@ -1,283 +1,297 @@ -import { OpenRouterHandler } from '../openrouter' -import { ApiHandlerOptions, ModelInfo } from '../../../shared/api' -import OpenAI from 'openai' -import axios from 'axios' -import { Anthropic } from '@anthropic-ai/sdk' +import { OpenRouterHandler } from "../openrouter" +import { ApiHandlerOptions, ModelInfo } from "../../../shared/api" +import OpenAI from "openai" +import axios from "axios" +import { Anthropic } from "@anthropic-ai/sdk" // Mock dependencies -jest.mock('openai') -jest.mock('axios') -jest.mock('delay', () => jest.fn(() => Promise.resolve())) +jest.mock("openai") +jest.mock("axios") +jest.mock("delay", () => jest.fn(() => Promise.resolve())) -describe('OpenRouterHandler', () => { - const mockOptions: ApiHandlerOptions = { - openRouterApiKey: 'test-key', - openRouterModelId: 'test-model', - openRouterModelInfo: { - name: 'Test Model', - description: 'Test Description', - maxTokens: 1000, - contextWindow: 2000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.02 - } as ModelInfo - } +describe("OpenRouterHandler", () => { + const mockOptions: ApiHandlerOptions = { + openRouterApiKey: "test-key", + openRouterModelId: "test-model", + openRouterModelInfo: { + name: "Test Model", + description: "Test Description", + maxTokens: 1000, + contextWindow: 2000, + supportsPromptCache: true, + inputPrice: 0.01, + outputPrice: 0.02, + } as ModelInfo, + } - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => { + jest.clearAllMocks() + }) - test('constructor initializes with correct options', () => { - const handler = new OpenRouterHandler(mockOptions) - expect(handler).toBeInstanceOf(OpenRouterHandler) - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: 'https://openrouter.ai/api/v1', - apiKey: mockOptions.openRouterApiKey, - defaultHeaders: { - 'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline', - 'X-Title': 'Roo-Cline', - }, - }) - }) + test("constructor initializes with correct options", () => { + const handler = new OpenRouterHandler(mockOptions) + expect(handler).toBeInstanceOf(OpenRouterHandler) + expect(OpenAI).toHaveBeenCalledWith({ + baseURL: "https://openrouter.ai/api/v1", + apiKey: mockOptions.openRouterApiKey, + defaultHeaders: { + "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", + "X-Title": "Roo-Cline", + }, + }) + }) - test('getModel returns correct model info when options are provided', () => { - const handler = new OpenRouterHandler(mockOptions) - const result = handler.getModel() - - expect(result).toEqual({ - id: mockOptions.openRouterModelId, - info: mockOptions.openRouterModelInfo - }) - }) + test("getModel returns correct model info when options are provided", () => { + const handler = new OpenRouterHandler(mockOptions) + const result = handler.getModel() - test('getModel returns default model info when options are not provided', () => { - const handler = new OpenRouterHandler({}) - const result = handler.getModel() - - expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta') - expect(result.info.supportsPromptCache).toBe(true) - }) + expect(result).toEqual({ + id: mockOptions.openRouterModelId, + info: mockOptions.openRouterModelInfo, + }) + }) - test('createMessage generates correct stream chunks', async () => { - const handler = new OpenRouterHandler(mockOptions) - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - id: 'test-id', - choices: [{ - delta: { - content: 'test response' - } - }] - } - } - } + test("getModel returns default model info when options are not provided", () => { + const handler = new OpenRouterHandler({}) + const result = handler.getModel() - // Mock OpenAI chat.completions.create - const mockCreate = jest.fn().mockResolvedValue(mockStream) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any + expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta") + expect(result.info.supportsPromptCache).toBe(true) + }) - // Mock axios.get for generation details - ;(axios.get as jest.Mock).mockResolvedValue({ - data: { - data: { - native_tokens_prompt: 10, - native_tokens_completion: 20, - total_cost: 0.001 - } - } - }) + test("createMessage generates correct stream chunks", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + id: "test-id", + choices: [ + { + delta: { + content: "test response", + }, + }, + ], + } + }, + } - const systemPrompt = 'test system prompt' - const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }] + // Mock OpenAI chat.completions.create + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any - const generator = handler.createMessage(systemPrompt, messages) - const chunks = [] - - for await (const chunk of generator) { - chunks.push(chunk) - } + // Mock axios.get for generation details + ;(axios.get as jest.Mock).mockResolvedValue({ + data: { + data: { + native_tokens_prompt: 10, + native_tokens_completion: 20, + total_cost: 0.001, + }, + }, + }) - // Verify stream chunks - expect(chunks).toHaveLength(2) // One text chunk and one usage chunk - expect(chunks[0]).toEqual({ - type: 'text', - text: 'test response' - }) - expect(chunks[1]).toEqual({ - type: 'usage', - inputTokens: 10, - outputTokens: 20, - totalCost: 0.001, - fullResponseText: 'test response' - }) + const systemPrompt = "test system prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }] - // Verify OpenAI client was called with correct parameters - expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ - model: mockOptions.openRouterModelId, - temperature: 0, - messages: expect.arrayContaining([ - { role: 'system', content: systemPrompt }, - { role: 'user', content: 'test message' } - ]), - stream: true - })) - }) + const generator = handler.createMessage(systemPrompt, messages) + const chunks = [] - test('createMessage with middle-out transform enabled', async () => { - const handler = new OpenRouterHandler({ - ...mockOptions, - openRouterUseMiddleOutTransform: true - }) - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - id: 'test-id', - choices: [{ - delta: { - content: 'test response' - } - }] - } - } - } + for await (const chunk of generator) { + chunks.push(chunk) + } - const mockCreate = jest.fn().mockResolvedValue(mockStream) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any - ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) + // Verify stream chunks + expect(chunks).toHaveLength(2) // One text chunk and one usage chunk + expect(chunks[0]).toEqual({ + type: "text", + text: "test response", + }) + expect(chunks[1]).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 20, + totalCost: 0.001, + fullResponseText: "test response", + }) - await handler.createMessage('test', []).next() + // Verify OpenAI client was called with correct parameters + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockOptions.openRouterModelId, + temperature: 0, + messages: expect.arrayContaining([ + { role: "system", content: systemPrompt }, + { role: "user", content: "test message" }, + ]), + stream: true, + }), + ) + }) - expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ - transforms: ['middle-out'] - })) - }) + test("createMessage with middle-out transform enabled", async () => { + const handler = new OpenRouterHandler({ + ...mockOptions, + openRouterUseMiddleOutTransform: true, + }) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + id: "test-id", + choices: [ + { + delta: { + content: "test response", + }, + }, + ], + } + }, + } - test('createMessage with Claude model adds cache control', async () => { - const handler = new OpenRouterHandler({ - ...mockOptions, - openRouterModelId: 'anthropic/claude-3.5-sonnet' - }) - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - id: 'test-id', - choices: [{ - delta: { - content: 'test response' - } - }] - } - } - } + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) - const mockCreate = jest.fn().mockResolvedValue(mockStream) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any - ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) + await handler.createMessage("test", []).next() - const messages: Anthropic.Messages.MessageParam[] = [ - { role: 'user', content: 'message 1' }, - { role: 'assistant', content: 'response 1' }, - { role: 'user', content: 'message 2' } - ] + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + transforms: ["middle-out"], + }), + ) + }) - await handler.createMessage('test system', messages).next() + test("createMessage with Claude model adds cache control", async () => { + const handler = new OpenRouterHandler({ + ...mockOptions, + openRouterModelId: "anthropic/claude-3.5-sonnet", + }) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + id: "test-id", + choices: [ + { + delta: { + content: "test response", + }, + }, + ], + } + }, + } - expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ - messages: expect.arrayContaining([ - expect.objectContaining({ - role: 'system', - content: expect.arrayContaining([ - expect.objectContaining({ - cache_control: { type: 'ephemeral' } - }) - ]) - }) - ]) - })) - }) + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) - test('createMessage handles API errors', async () => { - const handler = new OpenRouterHandler(mockOptions) - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - error: { - message: 'API Error', - code: 500 - } - } - } - } + const messages: Anthropic.Messages.MessageParam[] = [ + { role: "user", content: "message 1" }, + { role: "assistant", content: "response 1" }, + { role: "user", content: "message 2" }, + ] - const mockCreate = jest.fn().mockResolvedValue(mockStream) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any + await handler.createMessage("test system", messages).next() - const generator = handler.createMessage('test', []) - await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error') - }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: "system", + content: expect.arrayContaining([ + expect.objectContaining({ + cache_control: { type: "ephemeral" }, + }), + ]), + }), + ]), + }), + ) + }) - test('completePrompt returns correct response', async () => { - const handler = new OpenRouterHandler(mockOptions) - const mockResponse = { - choices: [{ - message: { - content: 'test completion' - } - }] - } + test("createMessage handles API errors", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + error: { + message: "API Error", + code: 500, + }, + } + }, + } - const mockCreate = jest.fn().mockResolvedValue(mockResponse) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any - const result = await handler.completePrompt('test prompt') + const generator = handler.createMessage("test", []) + await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error") + }) - expect(result).toBe('test completion') - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.openRouterModelId, - messages: [{ role: 'user', content: 'test prompt' }], - temperature: 0, - stream: false - }) - }) + test("completePrompt returns correct response", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockResponse = { + choices: [ + { + message: { + content: "test completion", + }, + }, + ], + } - test('completePrompt handles API errors', async () => { - const handler = new OpenRouterHandler(mockOptions) - const mockError = { - error: { - message: 'API Error', - code: 500 - } - } + const mockCreate = jest.fn().mockResolvedValue(mockResponse) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any - const mockCreate = jest.fn().mockResolvedValue(mockError) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any + const result = await handler.completePrompt("test prompt") - await expect(handler.completePrompt('test prompt')) - .rejects.toThrow('OpenRouter API Error 500: API Error') - }) + expect(result).toBe("test completion") + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.openRouterModelId, + messages: [{ role: "user", content: "test prompt" }], + temperature: 0, + stream: false, + }) + }) - test('completePrompt handles unexpected errors', async () => { - const handler = new OpenRouterHandler(mockOptions) - const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error')) - ;(OpenAI as jest.MockedClass).prototype.chat = { - completions: { create: mockCreate } - } as any + test("completePrompt handles API errors", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockError = { + error: { + message: "API Error", + code: 500, + }, + } - await expect(handler.completePrompt('test prompt')) - .rejects.toThrow('OpenRouter completion error: Unexpected error') - }) + const mockCreate = jest.fn().mockResolvedValue(mockError) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + + await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error") + }) + + test("completePrompt handles unexpected errors", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error")) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + + await expect(handler.completePrompt("test prompt")).rejects.toThrow( + "OpenRouter completion error: Unexpected error", + ) + }) }) diff --git a/src/api/providers/__tests__/vertex.test.ts b/src/api/providers/__tests__/vertex.test.ts index be5899f..a51033a 100644 --- a/src/api/providers/__tests__/vertex.test.ts +++ b/src/api/providers/__tests__/vertex.test.ts @@ -1,296 +1,295 @@ -import { VertexHandler } from '../vertex'; -import { Anthropic } from '@anthropic-ai/sdk'; -import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'; +import { VertexHandler } from "../vertex" +import { Anthropic } from "@anthropic-ai/sdk" +import { AnthropicVertex } from "@anthropic-ai/vertex-sdk" // Mock Vertex SDK -jest.mock('@anthropic-ai/vertex-sdk', () => ({ - AnthropicVertex: jest.fn().mockImplementation(() => ({ - messages: { - create: jest.fn().mockImplementation(async (options) => { - if (!options.stream) { - return { - id: 'test-completion', - content: [ - { type: 'text', text: 'Test response' } - ], - role: 'assistant', - model: options.model, - usage: { - input_tokens: 10, - output_tokens: 5 - } - } - } - return { - async *[Symbol.asyncIterator]() { - yield { - type: 'message_start', - message: { - usage: { - input_tokens: 10, - output_tokens: 5 - } - } - } - yield { - type: 'content_block_start', - content_block: { - type: 'text', - text: 'Test response' - } - } - } - } - }) - } - })) -})); +jest.mock("@anthropic-ai/vertex-sdk", () => ({ + AnthropicVertex: jest.fn().mockImplementation(() => ({ + messages: { + create: jest.fn().mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + content: [{ type: "text", text: "Test response" }], + role: "assistant", + model: options.model, + usage: { + input_tokens: 10, + output_tokens: 5, + }, + } + } + return { + async *[Symbol.asyncIterator]() { + yield { + type: "message_start", + message: { + usage: { + input_tokens: 10, + output_tokens: 5, + }, + }, + } + yield { + type: "content_block_start", + content_block: { + type: "text", + text: "Test response", + }, + } + }, + } + }), + }, + })), +})) -describe('VertexHandler', () => { - let handler: VertexHandler; +describe("VertexHandler", () => { + let handler: VertexHandler - beforeEach(() => { - handler = new VertexHandler({ - apiModelId: 'claude-3-5-sonnet-v2@20241022', - vertexProjectId: 'test-project', - vertexRegion: 'us-central1' - }); - }); + beforeEach(() => { + handler = new VertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + }) - describe('constructor', () => { - it('should initialize with provided config', () => { - expect(AnthropicVertex).toHaveBeenCalledWith({ - projectId: 'test-project', - region: 'us-central1' - }); - }); - }); + describe("constructor", () => { + it("should initialize with provided config", () => { + expect(AnthropicVertex).toHaveBeenCalledWith({ + projectId: "test-project", + region: "us-central1", + }) + }) + }) - describe('createMessage', () => { - const mockMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello' - }, - { - role: 'assistant', - content: 'Hi there!' - } - ]; + describe("createMessage", () => { + const mockMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello", + }, + { + role: "assistant", + content: "Hi there!", + }, + ] - const systemPrompt = 'You are a helpful assistant'; + const systemPrompt = "You are a helpful assistant" - it('should handle streaming responses correctly', async () => { - const mockStream = [ - { - type: 'message_start', - message: { - usage: { - input_tokens: 10, - output_tokens: 0 - } - } - }, - { - type: 'content_block_start', - index: 0, - content_block: { - type: 'text', - text: 'Hello' - } - }, - { - type: 'content_block_delta', - delta: { - type: 'text_delta', - text: ' world!' - } - }, - { - type: 'message_delta', - usage: { - output_tokens: 5 - } - } - ]; + it("should handle streaming responses correctly", async () => { + const mockStream = [ + { + type: "message_start", + message: { + usage: { + input_tokens: 10, + output_tokens: 0, + }, + }, + }, + { + type: "content_block_start", + index: 0, + content_block: { + type: "text", + text: "Hello", + }, + }, + { + type: "content_block_delta", + delta: { + type: "text_delta", + text: " world!", + }, + }, + { + type: "message_delta", + usage: { + output_tokens: 5, + }, + }, + ] - // Setup async iterator for mock stream - const asyncIterator = { - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - } - }; + // Setup async iterator for mock stream + const asyncIterator = { + async *[Symbol.asyncIterator]() { + for (const chunk of mockStream) { + yield chunk + } + }, + } - const mockCreate = jest.fn().mockResolvedValue(asyncIterator); - (handler['client'].messages as any).create = mockCreate; + const mockCreate = jest.fn().mockResolvedValue(asyncIterator) + ;(handler["client"].messages as any).create = mockCreate - const stream = handler.createMessage(systemPrompt, mockMessages); - const chunks = []; - - for await (const chunk of stream) { - chunks.push(chunk); - } + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks = [] - expect(chunks.length).toBe(4); - expect(chunks[0]).toEqual({ - type: 'usage', - inputTokens: 10, - outputTokens: 0 - }); - expect(chunks[1]).toEqual({ - type: 'text', - text: 'Hello' - }); - expect(chunks[2]).toEqual({ - type: 'text', - text: ' world!' - }); - expect(chunks[3]).toEqual({ - type: 'usage', - inputTokens: 0, - outputTokens: 5 - }); + for await (const chunk of stream) { + chunks.push(chunk) + } - expect(mockCreate).toHaveBeenCalledWith({ - model: 'claude-3-5-sonnet-v2@20241022', - max_tokens: 8192, - temperature: 0, - system: systemPrompt, - messages: mockMessages, - stream: true - }); - }); + expect(chunks.length).toBe(4) + expect(chunks[0]).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 0, + }) + expect(chunks[1]).toEqual({ + type: "text", + text: "Hello", + }) + expect(chunks[2]).toEqual({ + type: "text", + text: " world!", + }) + expect(chunks[3]).toEqual({ + type: "usage", + inputTokens: 0, + outputTokens: 5, + }) - it('should handle multiple content blocks with line breaks', async () => { - const mockStream = [ - { - type: 'content_block_start', - index: 0, - content_block: { - type: 'text', - text: 'First line' - } - }, - { - type: 'content_block_start', - index: 1, - content_block: { - type: 'text', - text: 'Second line' - } - } - ]; + expect(mockCreate).toHaveBeenCalledWith({ + model: "claude-3-5-sonnet-v2@20241022", + max_tokens: 8192, + temperature: 0, + system: systemPrompt, + messages: mockMessages, + stream: true, + }) + }) - const asyncIterator = { - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - } - }; + it("should handle multiple content blocks with line breaks", async () => { + const mockStream = [ + { + type: "content_block_start", + index: 0, + content_block: { + type: "text", + text: "First line", + }, + }, + { + type: "content_block_start", + index: 1, + content_block: { + type: "text", + text: "Second line", + }, + }, + ] - const mockCreate = jest.fn().mockResolvedValue(asyncIterator); - (handler['client'].messages as any).create = mockCreate; + const asyncIterator = { + async *[Symbol.asyncIterator]() { + for (const chunk of mockStream) { + yield chunk + } + }, + } - const stream = handler.createMessage(systemPrompt, mockMessages); - const chunks = []; - - for await (const chunk of stream) { - chunks.push(chunk); - } + const mockCreate = jest.fn().mockResolvedValue(asyncIterator) + ;(handler["client"].messages as any).create = mockCreate - expect(chunks.length).toBe(3); - expect(chunks[0]).toEqual({ - type: 'text', - text: 'First line' - }); - expect(chunks[1]).toEqual({ - type: 'text', - text: '\n' - }); - expect(chunks[2]).toEqual({ - type: 'text', - text: 'Second line' - }); - }); + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks = [] - it('should handle API errors', async () => { - const mockError = new Error('Vertex API error'); - const mockCreate = jest.fn().mockRejectedValue(mockError); - (handler['client'].messages as any).create = mockCreate; + for await (const chunk of stream) { + chunks.push(chunk) + } - const stream = handler.createMessage(systemPrompt, mockMessages); + expect(chunks.length).toBe(3) + expect(chunks[0]).toEqual({ + type: "text", + text: "First line", + }) + expect(chunks[1]).toEqual({ + type: "text", + text: "\n", + }) + expect(chunks[2]).toEqual({ + type: "text", + text: "Second line", + }) + }) - await expect(async () => { - for await (const chunk of stream) { - // Should throw before yielding any chunks - } - }).rejects.toThrow('Vertex API error'); - }); - }); + it("should handle API errors", async () => { + const mockError = new Error("Vertex API error") + const mockCreate = jest.fn().mockRejectedValue(mockError) + ;(handler["client"].messages as any).create = mockCreate - describe('completePrompt', () => { - it('should complete prompt successfully', async () => { - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe('Test response'); - expect(handler['client'].messages.create).toHaveBeenCalledWith({ - model: 'claude-3-5-sonnet-v2@20241022', - max_tokens: 8192, - temperature: 0, - messages: [{ role: 'user', content: 'Test prompt' }], - stream: false - }); - }); + const stream = handler.createMessage(systemPrompt, mockMessages) - it('should handle API errors', async () => { - const mockError = new Error('Vertex API error'); - const mockCreate = jest.fn().mockRejectedValue(mockError); - (handler['client'].messages as any).create = mockCreate; + await expect(async () => { + for await (const chunk of stream) { + // Should throw before yielding any chunks + } + }).rejects.toThrow("Vertex API error") + }) + }) - await expect(handler.completePrompt('Test prompt')) - .rejects.toThrow('Vertex completion error: Vertex API error'); - }); + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + expect(handler["client"].messages.create).toHaveBeenCalledWith({ + model: "claude-3-5-sonnet-v2@20241022", + max_tokens: 8192, + temperature: 0, + messages: [{ role: "user", content: "Test prompt" }], + stream: false, + }) + }) - it('should handle non-text content', async () => { - const mockCreate = jest.fn().mockResolvedValue({ - content: [{ type: 'image' }] - }); - (handler['client'].messages as any).create = mockCreate; + it("should handle API errors", async () => { + const mockError = new Error("Vertex API error") + const mockCreate = jest.fn().mockRejectedValue(mockError) + ;(handler["client"].messages as any).create = mockCreate - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); + await expect(handler.completePrompt("Test prompt")).rejects.toThrow( + "Vertex completion error: Vertex API error", + ) + }) - it('should handle empty response', async () => { - const mockCreate = jest.fn().mockResolvedValue({ - content: [{ type: 'text', text: '' }] - }); - (handler['client'].messages as any).create = mockCreate; + it("should handle non-text content", async () => { + const mockCreate = jest.fn().mockResolvedValue({ + content: [{ type: "image" }], + }) + ;(handler["client"].messages as any).create = mockCreate - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(''); - }); - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) - describe('getModel', () => { - it('should return correct model info', () => { - const modelInfo = handler.getModel(); - expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); - expect(modelInfo.info).toBeDefined(); - expect(modelInfo.info.maxTokens).toBe(8192); - expect(modelInfo.info.contextWindow).toBe(200_000); - }); + it("should handle empty response", async () => { + const mockCreate = jest.fn().mockResolvedValue({ + content: [{ type: "text", text: "" }], + }) + ;(handler["client"].messages as any).create = mockCreate - it('should return default model if invalid model specified', () => { - const invalidHandler = new VertexHandler({ - apiModelId: 'invalid-model', - vertexProjectId: 'test-project', - vertexRegion: 'us-central1' - }); - const modelInfo = invalidHandler.getModel(); - expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model - }); - }); -}); \ No newline at end of file + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) + + describe("getModel", () => { + it("should return correct model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.maxTokens).toBe(8192) + expect(modelInfo.info.contextWindow).toBe(200_000) + }) + + it("should return default model if invalid model specified", () => { + const invalidHandler = new VertexHandler({ + apiModelId: "invalid-model", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + const modelInfo = invalidHandler.getModel() + expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model + }) + }) +}) diff --git a/src/api/providers/__tests__/vscode-lm.test.ts b/src/api/providers/__tests__/vscode-lm.test.ts index 396f13f..34e0d60 100644 --- a/src/api/providers/__tests__/vscode-lm.test.ts +++ b/src/api/providers/__tests__/vscode-lm.test.ts @@ -1,289 +1,295 @@ -import * as vscode from 'vscode'; -import { VsCodeLmHandler } from '../vscode-lm'; -import { ApiHandlerOptions } from '../../../shared/api'; -import { Anthropic } from '@anthropic-ai/sdk'; +import * as vscode from "vscode" +import { VsCodeLmHandler } from "../vscode-lm" +import { ApiHandlerOptions } from "../../../shared/api" +import { Anthropic } from "@anthropic-ai/sdk" // Mock vscode namespace -jest.mock('vscode', () => { +jest.mock("vscode", () => { class MockLanguageModelTextPart { - type = 'text'; + type = "text" constructor(public value: string) {} } class MockLanguageModelToolCallPart { - type = 'tool_call'; + type = "tool_call" constructor( public callId: string, public name: string, - public input: any + public input: any, ) {} } return { workspace: { onDidChangeConfiguration: jest.fn((callback) => ({ - dispose: jest.fn() - })) + dispose: jest.fn(), + })), }, CancellationTokenSource: jest.fn(() => ({ token: { isCancellationRequested: false, - onCancellationRequested: jest.fn() + onCancellationRequested: jest.fn(), }, cancel: jest.fn(), - dispose: jest.fn() + dispose: jest.fn(), })), CancellationError: class CancellationError extends Error { constructor() { - super('Operation cancelled'); - this.name = 'CancellationError'; + super("Operation cancelled") + this.name = "CancellationError" } }, LanguageModelChatMessage: { Assistant: jest.fn((content) => ({ - role: 'assistant', - content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] + role: "assistant", + content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)], })), User: jest.fn((content) => ({ - role: 'user', - content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] - })) + role: "user", + content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)], + })), }, LanguageModelTextPart: MockLanguageModelTextPart, LanguageModelToolCallPart: MockLanguageModelToolCallPart, lm: { - selectChatModels: jest.fn() - } - }; -}); + selectChatModels: jest.fn(), + }, + } +}) const mockLanguageModelChat = { - id: 'test-model', - name: 'Test Model', - vendor: 'test-vendor', - family: 'test-family', - version: '1.0', + id: "test-model", + name: "Test Model", + vendor: "test-vendor", + family: "test-family", + version: "1.0", maxInputTokens: 4096, sendRequest: jest.fn(), - countTokens: jest.fn() -}; + countTokens: jest.fn(), +} -describe('VsCodeLmHandler', () => { - let handler: VsCodeLmHandler; +describe("VsCodeLmHandler", () => { + let handler: VsCodeLmHandler const defaultOptions: ApiHandlerOptions = { vsCodeLmModelSelector: { - vendor: 'test-vendor', - family: 'test-family' - } - }; + vendor: "test-vendor", + family: "test-family", + }, + } beforeEach(() => { - jest.clearAllMocks(); - handler = new VsCodeLmHandler(defaultOptions); - }); + jest.clearAllMocks() + handler = new VsCodeLmHandler(defaultOptions) + }) afterEach(() => { - handler.dispose(); - }); + handler.dispose() + }) - describe('constructor', () => { - it('should initialize with provided options', () => { - expect(handler).toBeDefined(); - expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled(); - }); + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeDefined() + expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled() + }) - it('should handle configuration changes', () => { - const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]; - callback({ affectsConfiguration: () => true }); + it("should handle configuration changes", () => { + const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0] + callback({ affectsConfiguration: () => true }) // Should reset client when config changes - expect(handler['client']).toBeNull(); - }); - }); + expect(handler["client"]).toBeNull() + }) + }) - describe('createClient', () => { - it('should create client with selector', async () => { - const mockModel = { ...mockLanguageModelChat }; - (vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); + describe("createClient", () => { + it("should create client with selector", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) - const client = await handler['createClient']({ - vendor: 'test-vendor', - family: 'test-family' - }); + const client = await handler["createClient"]({ + vendor: "test-vendor", + family: "test-family", + }) - expect(client).toBeDefined(); - expect(client.id).toBe('test-model'); + expect(client).toBeDefined() + expect(client.id).toBe("test-model") expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({ - vendor: 'test-vendor', - family: 'test-family' - }); - }); + vendor: "test-vendor", + family: "test-family", + }) + }) - it('should return default client when no models available', async () => { - (vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]); + it("should return default client when no models available", async () => { + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]) - const client = await handler['createClient']({}); - - expect(client).toBeDefined(); - expect(client.id).toBe('default-lm'); - expect(client.vendor).toBe('vscode'); - }); - }); + const client = await handler["createClient"]({}) - describe('createMessage', () => { + expect(client).toBeDefined() + expect(client.id).toBe("default-lm") + expect(client.vendor).toBe("vscode") + }) + }) + + describe("createMessage", () => { beforeEach(() => { - const mockModel = { ...mockLanguageModelChat }; - (vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); - mockLanguageModelChat.countTokens.mockResolvedValue(10); - }); + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) + mockLanguageModelChat.countTokens.mockResolvedValue(10) + }) - it('should stream text responses', async () => { - const systemPrompt = 'You are a helpful assistant'; - const messages: Anthropic.Messages.MessageParam[] = [{ - role: 'user' as const, - content: 'Hello' - }]; + it("should stream text responses", async () => { + const systemPrompt = "You are a helpful assistant" + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user" as const, + content: "Hello", + }, + ] - const responseText = 'Hello! How can I help you?'; + const responseText = "Hello! How can I help you?" mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ stream: (async function* () { - yield new vscode.LanguageModelTextPart(responseText); - return; + yield new vscode.LanguageModelTextPart(responseText) + return })(), text: (async function* () { - yield responseText; - return; - })() - }); + yield responseText + return + })(), + }) - const stream = handler.createMessage(systemPrompt, messages); - const chunks = []; + const stream = handler.createMessage(systemPrompt, messages) + const chunks = [] for await (const chunk of stream) { - chunks.push(chunk); + chunks.push(chunk) } - expect(chunks).toHaveLength(2); // Text chunk + usage chunk + expect(chunks).toHaveLength(2) // Text chunk + usage chunk expect(chunks[0]).toEqual({ - type: 'text', - text: responseText - }); + type: "text", + text: responseText, + }) expect(chunks[1]).toMatchObject({ - type: 'usage', + type: "usage", inputTokens: expect.any(Number), - outputTokens: expect.any(Number) - }); - }); + outputTokens: expect.any(Number), + }) + }) - it('should handle tool calls', async () => { - const systemPrompt = 'You are a helpful assistant'; - const messages: Anthropic.Messages.MessageParam[] = [{ - role: 'user' as const, - content: 'Calculate 2+2' - }]; + it("should handle tool calls", async () => { + const systemPrompt = "You are a helpful assistant" + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user" as const, + content: "Calculate 2+2", + }, + ] const toolCallData = { - name: 'calculator', - arguments: { operation: 'add', numbers: [2, 2] }, - callId: 'call-1' - }; + name: "calculator", + arguments: { operation: "add", numbers: [2, 2] }, + callId: "call-1", + } mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ stream: (async function* () { yield new vscode.LanguageModelToolCallPart( toolCallData.callId, toolCallData.name, - toolCallData.arguments - ); - return; + toolCallData.arguments, + ) + return })(), text: (async function* () { - yield JSON.stringify({ type: 'tool_call', ...toolCallData }); - return; - })() - }); + yield JSON.stringify({ type: "tool_call", ...toolCallData }) + return + })(), + }) - const stream = handler.createMessage(systemPrompt, messages); - const chunks = []; + const stream = handler.createMessage(systemPrompt, messages) + const chunks = [] for await (const chunk of stream) { - chunks.push(chunk); + chunks.push(chunk) } - expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk + expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk expect(chunks[0]).toEqual({ - type: 'text', - text: JSON.stringify({ type: 'tool_call', ...toolCallData }) - }); - }); + type: "text", + text: JSON.stringify({ type: "tool_call", ...toolCallData }), + }) + }) - it('should handle errors', async () => { - const systemPrompt = 'You are a helpful assistant'; - const messages: Anthropic.Messages.MessageParam[] = [{ - role: 'user' as const, - content: 'Hello' - }]; + it("should handle errors", async () => { + const systemPrompt = "You are a helpful assistant" + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user" as const, + content: "Hello", + }, + ] - mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error')); + mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error")) await expect(async () => { - const stream = handler.createMessage(systemPrompt, messages); + const stream = handler.createMessage(systemPrompt, messages) for await (const _ of stream) { // consume stream } - }).rejects.toThrow('API Error'); - }); - }); + }).rejects.toThrow("API Error") + }) + }) + + describe("getModel", () => { + it("should return model info when client exists", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) - describe('getModel', () => { - it('should return model info when client exists', async () => { - const mockModel = { ...mockLanguageModelChat }; - (vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); - // Initialize client - await handler['getClient'](); - - const model = handler.getModel(); - expect(model.id).toBe('test-model'); - expect(model.info).toBeDefined(); - expect(model.info.contextWindow).toBe(4096); - }); + await handler["getClient"]() - it('should return fallback model info when no client exists', () => { - const model = handler.getModel(); - expect(model.id).toBe('test-vendor/test-family'); - expect(model.info).toBeDefined(); - }); - }); + const model = handler.getModel() + expect(model.id).toBe("test-model") + expect(model.info).toBeDefined() + expect(model.info.contextWindow).toBe(4096) + }) - describe('completePrompt', () => { - it('should complete single prompt', async () => { - const mockModel = { ...mockLanguageModelChat }; - (vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); + it("should return fallback model info when no client exists", () => { + const model = handler.getModel() + expect(model.id).toBe("test-vendor/test-family") + expect(model.info).toBeDefined() + }) + }) - const responseText = 'Completed text'; + describe("completePrompt", () => { + it("should complete single prompt", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ stream: (async function* () { - yield new vscode.LanguageModelTextPart(responseText); - return; + yield new vscode.LanguageModelTextPart(responseText) + return })(), text: (async function* () { - yield responseText; - return; - })() - }); + yield responseText + return + })(), + }) - const result = await handler.completePrompt('Test prompt'); - expect(result).toBe(responseText); - expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled(); - }); + const result = await handler.completePrompt("Test prompt") + expect(result).toBe(responseText) + expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled() + }) - it('should handle errors during completion', async () => { - const mockModel = { ...mockLanguageModelChat }; - (vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); + it("should handle errors during completion", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) - mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed')); + mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed")) - await expect(handler.completePrompt('Test prompt')) - .rejects - .toThrow('VSCode LM completion error: Completion failed'); - }); - }); -}); \ No newline at end of file + await expect(handler.completePrompt("Test prompt")).rejects.toThrow( + "VSCode LM completion error: Completion failed", + ) + }) + }) +}) diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 5184281..e65b82d 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -181,14 +181,14 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler { max_tokens: this.getModel().info.maxTokens || 8192, temperature: 0, messages: [{ role: "user", content: prompt }], - stream: false + stream: false, }) const content = response.content[0] - if (content.type === 'text') { + if (content.type === "text") { return content.text } - return '' + return "" } catch (error) { if (error instanceof Error) { throw new Error(`Anthropic completion error: ${error.message}`) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 3d07895..87591b7 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -1,4 +1,9 @@ -import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime" +import { + BedrockRuntimeClient, + ConverseStreamCommand, + ConverseCommand, + BedrockRuntimeClientConfig, +} from "@aws-sdk/client-bedrock-runtime" import { Anthropic } from "@anthropic-ai/sdk" import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" @@ -7,275 +12,276 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../ // Define types for stream events based on AWS SDK export interface StreamEvent { - messageStart?: { - role?: string; - }; - messageStop?: { - stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"; - additionalModelResponseFields?: Record; - }; - contentBlockStart?: { - start?: { - text?: string; - }; - contentBlockIndex?: number; - }; - contentBlockDelta?: { - delta?: { - text?: string; - }; - contentBlockIndex?: number; - }; - metadata?: { - usage?: { - inputTokens: number; - outputTokens: number; - totalTokens?: number; // Made optional since we don't use it - }; - metrics?: { - latencyMs: number; - }; - }; + messageStart?: { + role?: string + } + messageStop?: { + stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" + additionalModelResponseFields?: Record + } + contentBlockStart?: { + start?: { + text?: string + } + contentBlockIndex?: number + } + contentBlockDelta?: { + delta?: { + text?: string + } + contentBlockIndex?: number + } + metadata?: { + usage?: { + inputTokens: number + outputTokens: number + totalTokens?: number // Made optional since we don't use it + } + metrics?: { + latencyMs: number + } + } } export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler { - private options: ApiHandlerOptions - private client: BedrockRuntimeClient + private options: ApiHandlerOptions + private client: BedrockRuntimeClient - constructor(options: ApiHandlerOptions) { - this.options = options - - // Only include credentials if they actually exist - const clientConfig: BedrockRuntimeClientConfig = { - region: this.options.awsRegion || "us-east-1" - } + constructor(options: ApiHandlerOptions) { + this.options = options - if (this.options.awsAccessKey && this.options.awsSecretKey) { - // Create credentials object with all properties at once - clientConfig.credentials = { - accessKeyId: this.options.awsAccessKey, - secretAccessKey: this.options.awsSecretKey, - ...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}) - } - } + // Only include credentials if they actually exist + const clientConfig: BedrockRuntimeClientConfig = { + region: this.options.awsRegion || "us-east-1", + } - this.client = new BedrockRuntimeClient(clientConfig) - } + if (this.options.awsAccessKey && this.options.awsSecretKey) { + // Create credentials object with all properties at once + clientConfig.credentials = { + accessKeyId: this.options.awsAccessKey, + secretAccessKey: this.options.awsSecretKey, + ...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}), + } + } - async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const modelConfig = this.getModel() - - // Handle cross-region inference - let modelId: string - if (this.options.awsUseCrossRegionInference) { - let regionPrefix = (this.options.awsRegion || "").slice(0, 3) - switch (regionPrefix) { - case "us-": - modelId = `us.${modelConfig.id}` - break - case "eu-": - modelId = `eu.${modelConfig.id}` - break - default: - modelId = modelConfig.id - break - } - } else { - modelId = modelConfig.id - } + this.client = new BedrockRuntimeClient(clientConfig) + } - // Convert messages to Bedrock format - const formattedMessages = convertToBedrockConverseMessages(messages) + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const modelConfig = this.getModel() - // Construct the payload - const payload = { - modelId, - messages: formattedMessages, - system: [{ text: systemPrompt }], - inferenceConfig: { - maxTokens: modelConfig.info.maxTokens || 5000, - temperature: 0.3, - topP: 0.1, - ...(this.options.awsUsePromptCache ? { - promptCache: { - promptCacheId: this.options.awspromptCacheId || "" - } - } : {}) - } - } + // Handle cross-region inference + let modelId: string + if (this.options.awsUseCrossRegionInference) { + let regionPrefix = (this.options.awsRegion || "").slice(0, 3) + switch (regionPrefix) { + case "us-": + modelId = `us.${modelConfig.id}` + break + case "eu-": + modelId = `eu.${modelConfig.id}` + break + default: + modelId = modelConfig.id + break + } + } else { + modelId = modelConfig.id + } - try { - const command = new ConverseStreamCommand(payload) - const response = await this.client.send(command) + // Convert messages to Bedrock format + const formattedMessages = convertToBedrockConverseMessages(messages) - if (!response.stream) { - throw new Error('No stream available in the response') - } + // Construct the payload + const payload = { + modelId, + messages: formattedMessages, + system: [{ text: systemPrompt }], + inferenceConfig: { + maxTokens: modelConfig.info.maxTokens || 5000, + temperature: 0.3, + topP: 0.1, + ...(this.options.awsUsePromptCache + ? { + promptCache: { + promptCacheId: this.options.awspromptCacheId || "", + }, + } + : {}), + }, + } - for await (const chunk of response.stream) { - // Parse the chunk as JSON if it's a string (for tests) - let streamEvent: StreamEvent - try { - streamEvent = typeof chunk === 'string' ? - JSON.parse(chunk) : - chunk as unknown as StreamEvent - } catch (e) { - console.error('Failed to parse stream event:', e) - continue - } + try { + const command = new ConverseStreamCommand(payload) + const response = await this.client.send(command) - // Handle metadata events first - if (streamEvent.metadata?.usage) { - yield { - type: "usage", - inputTokens: streamEvent.metadata.usage.inputTokens || 0, - outputTokens: streamEvent.metadata.usage.outputTokens || 0 - } - continue - } + if (!response.stream) { + throw new Error("No stream available in the response") + } - // Handle message start - if (streamEvent.messageStart) { - continue - } + for await (const chunk of response.stream) { + // Parse the chunk as JSON if it's a string (for tests) + let streamEvent: StreamEvent + try { + streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent) + } catch (e) { + console.error("Failed to parse stream event:", e) + continue + } - // Handle content blocks - if (streamEvent.contentBlockStart?.start?.text) { - yield { - type: "text", - text: streamEvent.contentBlockStart.start.text - } - continue - } + // Handle metadata events first + if (streamEvent.metadata?.usage) { + yield { + type: "usage", + inputTokens: streamEvent.metadata.usage.inputTokens || 0, + outputTokens: streamEvent.metadata.usage.outputTokens || 0, + } + continue + } - // Handle content deltas - if (streamEvent.contentBlockDelta?.delta?.text) { - yield { - type: "text", - text: streamEvent.contentBlockDelta.delta.text - } - continue - } + // Handle message start + if (streamEvent.messageStart) { + continue + } - // Handle message stop - if (streamEvent.messageStop) { - continue - } - } + // Handle content blocks + if (streamEvent.contentBlockStart?.start?.text) { + yield { + type: "text", + text: streamEvent.contentBlockStart.start.text, + } + continue + } - } catch (error: unknown) { - console.error('Bedrock Runtime API Error:', error) - // Only access stack if error is an Error object - if (error instanceof Error) { - console.error('Error stack:', error.stack) - yield { - type: "text", - text: `Error: ${error.message}` - } - yield { - type: "usage", - inputTokens: 0, - outputTokens: 0 - } - throw error - } else { - const unknownError = new Error("An unknown error occurred") - yield { - type: "text", - text: unknownError.message - } - yield { - type: "usage", - inputTokens: 0, - outputTokens: 0 - } - throw unknownError - } - } - } + // Handle content deltas + if (streamEvent.contentBlockDelta?.delta?.text) { + yield { + type: "text", + text: streamEvent.contentBlockDelta.delta.text, + } + continue + } - getModel(): { id: BedrockModelId | string; info: ModelInfo } { - const modelId = this.options.apiModelId - if (modelId) { - // For tests, allow any model ID - if (process.env.NODE_ENV === 'test') { - return { - id: modelId, - info: { - maxTokens: 5000, - contextWindow: 128_000, - supportsPromptCache: false - } - } - } - // For production, validate against known models - if (modelId in bedrockModels) { - const id = modelId as BedrockModelId - return { id, info: bedrockModels[id] } - } - } - return { - id: bedrockDefaultModelId, - info: bedrockModels[bedrockDefaultModelId] - } - } + // Handle message stop + if (streamEvent.messageStop) { + continue + } + } + } catch (error: unknown) { + console.error("Bedrock Runtime API Error:", error) + // Only access stack if error is an Error object + if (error instanceof Error) { + console.error("Error stack:", error.stack) + yield { + type: "text", + text: `Error: ${error.message}`, + } + yield { + type: "usage", + inputTokens: 0, + outputTokens: 0, + } + throw error + } else { + const unknownError = new Error("An unknown error occurred") + yield { + type: "text", + text: unknownError.message, + } + yield { + type: "usage", + inputTokens: 0, + outputTokens: 0, + } + throw unknownError + } + } + } - async completePrompt(prompt: string): Promise { - try { - const modelConfig = this.getModel() - - // Handle cross-region inference - let modelId: string - if (this.options.awsUseCrossRegionInference) { - let regionPrefix = (this.options.awsRegion || "").slice(0, 3) - switch (regionPrefix) { - case "us-": - modelId = `us.${modelConfig.id}` - break - case "eu-": - modelId = `eu.${modelConfig.id}` - break - default: - modelId = modelConfig.id - break - } - } else { - modelId = modelConfig.id - } + getModel(): { id: BedrockModelId | string; info: ModelInfo } { + const modelId = this.options.apiModelId + if (modelId) { + // For tests, allow any model ID + if (process.env.NODE_ENV === "test") { + return { + id: modelId, + info: { + maxTokens: 5000, + contextWindow: 128_000, + supportsPromptCache: false, + }, + } + } + // For production, validate against known models + if (modelId in bedrockModels) { + const id = modelId as BedrockModelId + return { id, info: bedrockModels[id] } + } + } + return { + id: bedrockDefaultModelId, + info: bedrockModels[bedrockDefaultModelId], + } + } - const payload = { - modelId, - messages: convertToBedrockConverseMessages([{ - role: "user", - content: prompt - }]), - inferenceConfig: { - maxTokens: modelConfig.info.maxTokens || 5000, - temperature: 0.3, - topP: 0.1 - } - } + async completePrompt(prompt: string): Promise { + try { + const modelConfig = this.getModel() - const command = new ConverseCommand(payload) - const response = await this.client.send(command) + // Handle cross-region inference + let modelId: string + if (this.options.awsUseCrossRegionInference) { + let regionPrefix = (this.options.awsRegion || "").slice(0, 3) + switch (regionPrefix) { + case "us-": + modelId = `us.${modelConfig.id}` + break + case "eu-": + modelId = `eu.${modelConfig.id}` + break + default: + modelId = modelConfig.id + break + } + } else { + modelId = modelConfig.id + } - if (response.output && response.output instanceof Uint8Array) { - try { - const outputStr = new TextDecoder().decode(response.output) - const output = JSON.parse(outputStr) - if (output.content) { - return output.content - } - } catch (parseError) { - console.error('Failed to parse Bedrock response:', parseError) - } - } - return '' - } catch (error) { - if (error instanceof Error) { - throw new Error(`Bedrock completion error: ${error.message}`) - } - throw error - } - } + const payload = { + modelId, + messages: convertToBedrockConverseMessages([ + { + role: "user", + content: prompt, + }, + ]), + inferenceConfig: { + maxTokens: modelConfig.info.maxTokens || 5000, + temperature: 0.3, + topP: 0.1, + }, + } + + const command = new ConverseCommand(payload) + const response = await this.client.send(command) + + if (response.output && response.output instanceof Uint8Array) { + try { + const outputStr = new TextDecoder().decode(response.output) + const output = JSON.parse(outputStr) + if (output.content) { + return output.content + } + } catch (parseError) { + console.error("Failed to parse Bedrock response:", parseError) + } + } + return "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Bedrock completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/api/providers/deepseek.ts b/src/api/providers/deepseek.ts index de23d70..e559f98 100644 --- a/src/api/providers/deepseek.ts +++ b/src/api/providers/deepseek.ts @@ -3,24 +3,24 @@ import { ApiHandlerOptions, ModelInfo } from "../../shared/api" import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api" export class DeepSeekHandler extends OpenAiHandler { - constructor(options: ApiHandlerOptions) { - if (!options.deepSeekApiKey) { - throw new Error("DeepSeek API key is required. Please provide it in the settings.") - } - super({ - ...options, - openAiApiKey: options.deepSeekApiKey, - openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId, - openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1", - includeMaxTokens: true - }) - } + constructor(options: ApiHandlerOptions) { + if (!options.deepSeekApiKey) { + throw new Error("DeepSeek API key is required. Please provide it in the settings.") + } + super({ + ...options, + openAiApiKey: options.deepSeekApiKey, + openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId, + openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1", + includeMaxTokens: true, + }) + } - override getModel(): { id: string; info: ModelInfo } { - const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId - return { - id: modelId, - info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId] - } - } + override getModel(): { id: string; info: ModelInfo } { + const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId + return { + id: modelId, + info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId], + } + } } diff --git a/src/api/providers/glama.ts b/src/api/providers/glama.ts index 7e95d0c..7f10592 100644 --- a/src/api/providers/glama.ts +++ b/src/api/providers/glama.ts @@ -72,17 +72,17 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler { maxTokens = 8_192 } - const { data: completion, response } = await this.client.chat.completions.create({ - model: this.getModel().id, - max_tokens: maxTokens, - temperature: 0, - messages: openAiMessages, - stream: true, - }).withResponse(); + const { data: completion, response } = await this.client.chat.completions + .create({ + model: this.getModel().id, + max_tokens: maxTokens, + temperature: 0, + messages: openAiMessages, + stream: true, + }) + .withResponse() - const completionRequestId = response.headers.get( - 'x-completion-request-id', - ); + const completionRequestId = response.headers.get("x-completion-request-id") for await (const chunk of completion) { const delta = chunk.choices[0]?.delta @@ -96,13 +96,16 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler { } try { - const response = await axios.get(`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, { - headers: { - Authorization: `Bearer ${this.options.glamaApiKey}`, + const response = await axios.get( + `https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, + { + headers: { + Authorization: `Bearer ${this.options.glamaApiKey}`, + }, }, - }) + ) - const completionRequest = response.data; + const completionRequest = response.data if (completionRequest.tokenUsage) { yield { @@ -113,7 +116,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler { outputTokens: completionRequest.tokenUsage.completionTokens, totalCost: parseFloat(completionRequest.totalCostUsd), } - } + } } catch (error) { console.error("Error fetching Glama completion details", error) } @@ -126,7 +129,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler { if (modelId && modelInfo) { return { id: modelId, info: modelInfo } } - + return { id: glamaDefaultModelId, info: glamaDefaultModelInfo } } @@ -141,7 +144,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler { if (this.getModel().id.startsWith("anthropic/")) { requestOptions.max_tokens = 8192 } - + const response = await this.client.chat.completions.create(requestOptions) return response.choices[0]?.message.content || "" } catch (error) { diff --git a/src/api/providers/lmstudio.ts b/src/api/providers/lmstudio.ts index e5c6256..d07b9c2 100644 --- a/src/api/providers/lmstudio.ts +++ b/src/api/providers/lmstudio.ts @@ -60,7 +60,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler { model: this.getModel().id, messages: [{ role: "user", content: prompt }], temperature: 0, - stream: false + stream: false, }) return response.choices[0]?.message.content || "" } catch (error) { diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts index 9df73d6..98374c5 100644 --- a/src/api/providers/ollama.ts +++ b/src/api/providers/ollama.ts @@ -53,7 +53,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler { model: this.getModel().id, messages: [{ role: "user", content: prompt }], temperature: 0, - stream: false + stream: false, }) return response.choices[0]?.message.content || "" } catch (error) { diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index fa27eb3..0b8908d 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -32,7 +32,10 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler // o1 doesnt support streaming or non-1 temp but does support a developer prompt const response = await this.client.chat.completions.create({ model: modelId, - messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)], + messages: [ + { role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, + ...convertToOpenAiMessages(messages), + ], }) yield { type: "text", @@ -98,14 +101,14 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler // o1 doesn't support non-1 temp requestOptions = { model: modelId, - messages: [{ role: "user", content: prompt }] + messages: [{ role: "user", content: prompt }], } break default: requestOptions = { model: modelId, messages: [{ role: "user", content: prompt }], - temperature: 0 + temperature: 0, } } diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 0878028..f1dbe6e 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -17,7 +17,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler { constructor(options: ApiHandlerOptions) { this.options = options // Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai - const urlHost = new URL(this.options.openAiBaseUrl ?? "").host; + const urlHost = new URL(this.options.openAiBaseUrl ?? "").host if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) { this.client = new AzureOpenAI({ baseURL: this.options.openAiBaseUrl, @@ -39,7 +39,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler { if (this.options.openAiStreamingEnabled ?? true) { const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = { role: "system", - content: systemPrompt + content: systemPrompt, } const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { model: modelId, @@ -74,14 +74,14 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler { // o1 for instance doesnt support streaming, non-1 temp, or system prompt const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = { role: "user", - content: systemPrompt + content: systemPrompt, } const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { model: modelId, messages: [systemMessage, ...convertToOpenAiMessages(messages)], } const response = await this.client.chat.completions.create(requestOptions) - + yield { type: "text", text: response.choices[0]?.message.content || "", @@ -108,7 +108,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler { messages: [{ role: "user", content: prompt }], temperature: 0, } - + const response = await this.client.chat.completions.create(requestOptions) return response.choices[0]?.message.content || "" } catch (error) { diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 9bccf5c..c69d6fe 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -9,12 +9,12 @@ import delay from "delay" // Add custom interface for OpenRouter params type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & { - transforms?: string[]; + transforms?: string[] } // Add custom interface for OpenRouter usage chunk interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk { - fullResponseText: string; + fullResponseText: string } import { SingleCompletionHandler } from ".." @@ -35,7 +35,10 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { }) } - async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator { + async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + ): AsyncGenerator { // Convert Anthropic messages to OpenAI format const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, @@ -108,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { break } // https://openrouter.ai/docs/transforms - let fullResponseText = ""; + let fullResponseText = "" const stream = await this.client.chat.completions.create({ model: this.getModel().id, max_tokens: maxTokens, @@ -116,8 +119,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { messages: openAiMessages, stream: true, // This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true. - ...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }) - } as OpenRouterChatCompletionParams); + ...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }), + } as OpenRouterChatCompletionParams) let genId: string | undefined @@ -135,11 +138,11 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { const delta = chunk.choices[0]?.delta if (delta?.content) { - fullResponseText += delta.content; + fullResponseText += delta.content yield { type: "text", text: delta.content, - } as ApiStreamChunk; + } as ApiStreamChunk } // if (chunk.usage) { // yield { @@ -170,13 +173,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { inputTokens: generation?.native_tokens_prompt || 0, outputTokens: generation?.native_tokens_completion || 0, totalCost: generation?.total_cost || 0, - fullResponseText - } as OpenRouterApiStreamUsageChunk; + fullResponseText, + } as OpenRouterApiStreamUsageChunk } catch (error) { // ignore if fails console.error("Error fetching OpenRouter generation details:", error) } - } getModel(): { id: string; info: ModelInfo } { const modelId = this.options.openRouterModelId @@ -193,7 +195,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { model: this.getModel().id, messages: [{ role: "user", content: prompt }], temperature: 0, - stream: false + stream: false, }) if ("error" in response) { diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index aed704e..d997135 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -91,14 +91,14 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler { max_tokens: this.getModel().info.maxTokens || 8192, temperature: 0, messages: [{ role: "user", content: prompt }], - stream: false + stream: false, }) const content = response.content[0] - if (content.type === 'text') { + if (content.type === "text") { return content.text } - return '' + return "" } catch (error) { if (error instanceof Error) { throw new Error(`Vertex completion error: ${error.message}`) diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts index bde2d6a..6ddc6bd 100644 --- a/src/api/providers/vscode-lm.ts +++ b/src/api/providers/vscode-lm.ts @@ -1,31 +1,31 @@ -import { Anthropic } from "@anthropic-ai/sdk"; -import * as vscode from 'vscode'; -import { ApiHandler, SingleCompletionHandler } from "../"; -import { calculateApiCost } from "../../utils/cost"; -import { ApiStream } from "../transform/stream"; -import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"; -import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"; -import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"; +import { Anthropic } from "@anthropic-ai/sdk" +import * as vscode from "vscode" +import { ApiHandler, SingleCompletionHandler } from "../" +import { calculateApiCost } from "../../utils/cost" +import { ApiStream } from "../transform/stream" +import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format" +import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils" +import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api" /** * Handles interaction with VS Code's Language Model API for chat-based operations. * This handler implements the ApiHandler interface to provide VS Code LM specific functionality. - * + * * @implements {ApiHandler} - * + * * @remarks * The handler manages a VS Code language model chat client and provides methods to: * - Create and manage chat client instances * - Stream messages using VS Code's Language Model API * - Retrieve model information - * + * * @example * ```typescript * const options = { * vsCodeLmModelSelector: { vendor: "copilot", family: "gpt-4" } * }; * const handler = new VsCodeLmHandler(options); - * + * * // Stream a conversation * const systemPrompt = "You are a helpful assistant"; * const messages = [{ role: "user", content: "Hello!" }]; @@ -35,39 +35,36 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../.. * ``` */ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { - - private options: ApiHandlerOptions; - private client: vscode.LanguageModelChat | null; - private disposable: vscode.Disposable | null; - private currentRequestCancellation: vscode.CancellationTokenSource | null; + private options: ApiHandlerOptions + private client: vscode.LanguageModelChat | null + private disposable: vscode.Disposable | null + private currentRequestCancellation: vscode.CancellationTokenSource | null constructor(options: ApiHandlerOptions) { - this.options = options; - this.client = null; - this.disposable = null; - this.currentRequestCancellation = null; + this.options = options + this.client = null + this.disposable = null + this.currentRequestCancellation = null try { // Listen for model changes and reset client - this.disposable = vscode.workspace.onDidChangeConfiguration(event => { - if (event.affectsConfiguration('lm')) { + this.disposable = vscode.workspace.onDidChangeConfiguration((event) => { + if (event.affectsConfiguration("lm")) { try { - this.client = null; - this.ensureCleanState(); - } - catch (error) { - console.error('Error during configuration change cleanup:', error); + this.client = null + this.ensureCleanState() + } catch (error) { + console.error("Error during configuration change cleanup:", error) } } - }); - } - catch (error) { + }) + } catch (error) { // Ensure cleanup if constructor fails - this.dispose(); + this.dispose() throw new Error( - `Cline : Failed to initialize handler: ${error instanceof Error ? error.message : 'Unknown error'}` - ); + `Cline : Failed to initialize handler: ${error instanceof Error ? error.message : "Unknown error"}`, + ) } } @@ -77,46 +74,46 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { * @param selector - Selector criteria to filter language model chat instances * @returns Promise resolving to the first matching language model chat instance * @throws Error when no matching models are found with the given selector - * + * * @example * const selector = { vendor: "copilot", family: "gpt-4o" }; * const chatClient = await createClient(selector); */ async createClient(selector: vscode.LanguageModelChatSelector): Promise { try { - const models = await vscode.lm.selectChatModels(selector); + const models = await vscode.lm.selectChatModels(selector) // Use first available model or create a minimal model object if (models && Array.isArray(models) && models.length > 0) { - return models[0]; + return models[0] } // Create a minimal model if no models are available return { - id: 'default-lm', - name: 'Default Language Model', - vendor: 'vscode', - family: 'lm', - version: '1.0', + id: "default-lm", + name: "Default Language Model", + vendor: "vscode", + family: "lm", + version: "1.0", maxInputTokens: 8192, sendRequest: async (messages, options, token) => { // Provide a minimal implementation return { stream: (async function* () { yield new vscode.LanguageModelTextPart( - "Language model functionality is limited. Please check VS Code configuration." - ); + "Language model functionality is limited. Please check VS Code configuration.", + ) })(), text: (async function* () { - yield "Language model functionality is limited. Please check VS Code configuration."; - })() - }; + yield "Language model functionality is limited. Please check VS Code configuration." + })(), + } }, - countTokens: async () => 0 - }; + countTokens: async () => 0, + } } catch (error) { - const errorMessage = error instanceof Error ? error.message : 'Unknown error'; - throw new Error(`Cline : Failed to select model: ${errorMessage}`); + const errorMessage = error instanceof Error ? error.message : "Unknown error" + throw new Error(`Cline : Failed to select model: ${errorMessage}`) } } @@ -125,242 +122,234 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { * * @param systemPrompt - The system prompt to initialize the conversation context * @param messages - An array of message parameters following the Anthropic message format - * + * * @yields {ApiStream} An async generator that yields either text chunks or tool calls from the model response - * + * * @throws {Error} When vsCodeLmModelSelector option is not provided * @throws {Error} When the response stream encounters an error - * + * * @remarks * This method handles the initialization of the VS Code LM client if not already created, * converts the messages to VS Code LM format, and streams the response chunks. * Tool calls handling is currently a work in progress. */ dispose(): void { - if (this.disposable) { - - this.disposable.dispose(); + this.disposable.dispose() } if (this.currentRequestCancellation) { - - this.currentRequestCancellation.cancel(); - this.currentRequestCancellation.dispose(); + this.currentRequestCancellation.cancel() + this.currentRequestCancellation.dispose() } } private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise { // Check for required dependencies if (!this.client) { - console.warn('Cline : No client available for token counting'); - return 0; + console.warn("Cline : No client available for token counting") + return 0 } if (!this.currentRequestCancellation) { - console.warn('Cline : No cancellation token available for token counting'); - return 0; + console.warn("Cline : No cancellation token available for token counting") + return 0 } // Validate input if (!text) { - console.debug('Cline : Empty text provided for token counting'); - return 0; + console.debug("Cline : Empty text provided for token counting") + return 0 } try { // Handle different input types - let tokenCount: number; + let tokenCount: number - if (typeof text === 'string') { - tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token); + if (typeof text === "string") { + tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token) } else if (text instanceof vscode.LanguageModelChatMessage) { // For chat messages, ensure we have content if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) { - console.debug('Cline : Empty chat message content'); - return 0; + console.debug("Cline : Empty chat message content") + return 0 } - tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token); + tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token) } else { - console.warn('Cline : Invalid input type for token counting'); - return 0; + console.warn("Cline : Invalid input type for token counting") + return 0 } // Validate the result - if (typeof tokenCount !== 'number') { - console.warn('Cline : Non-numeric token count received:', tokenCount); - return 0; + if (typeof tokenCount !== "number") { + console.warn("Cline : Non-numeric token count received:", tokenCount) + return 0 } if (tokenCount < 0) { - console.warn('Cline : Negative token count received:', tokenCount); - return 0; + console.warn("Cline : Negative token count received:", tokenCount) + return 0 } - return tokenCount; - } - catch (error) { + return tokenCount + } catch (error) { // Handle specific error types if (error instanceof vscode.CancellationError) { - console.debug('Cline : Token counting cancelled by user'); - return 0; + console.debug("Cline : Token counting cancelled by user") + return 0 } - const errorMessage = error instanceof Error ? error.message : 'Unknown error'; - console.warn('Cline : Token counting failed:', errorMessage); + const errorMessage = error instanceof Error ? error.message : "Unknown error" + console.warn("Cline : Token counting failed:", errorMessage) // Log additional error details if available if (error instanceof Error && error.stack) { - console.debug('Token counting error stack:', error.stack); + console.debug("Token counting error stack:", error.stack) } - return 0; // Fallback to prevent stream interruption + return 0 // Fallback to prevent stream interruption } } - private async calculateTotalInputTokens(systemPrompt: string, vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise { + private async calculateTotalInputTokens( + systemPrompt: string, + vsCodeLmMessages: vscode.LanguageModelChatMessage[], + ): Promise { + const systemTokens: number = await this.countTokens(systemPrompt) - const systemTokens: number = await this.countTokens(systemPrompt); + const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.countTokens(msg))) - const messageTokens: number[] = await Promise.all( - vsCodeLmMessages.map(msg => this.countTokens(msg)) - ); - - return systemTokens + messageTokens.reduce( - (sum: number, tokens: number): number => sum + tokens, 0 - ); + return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0) } private ensureCleanState(): void { - if (this.currentRequestCancellation) { - - this.currentRequestCancellation.cancel(); - this.currentRequestCancellation.dispose(); - this.currentRequestCancellation = null; + this.currentRequestCancellation.cancel() + this.currentRequestCancellation.dispose() + this.currentRequestCancellation = null } } private async getClient(): Promise { if (!this.client) { - console.debug('Cline : Getting client with options:', { + console.debug("Cline : Getting client with options:", { vsCodeLmModelSelector: this.options.vsCodeLmModelSelector, hasOptions: !!this.options, - selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [] - }); + selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [], + }) try { // Use default empty selector if none provided to get all available models - const selector = this.options?.vsCodeLmModelSelector || {}; - console.debug('Cline : Creating client with selector:', selector); - this.client = await this.createClient(selector); + const selector = this.options?.vsCodeLmModelSelector || {} + console.debug("Cline : Creating client with selector:", selector) + this.client = await this.createClient(selector) } catch (error) { - const message = error instanceof Error ? error.message : 'Unknown error'; - console.error('Cline : Client creation failed:', message); - throw new Error(`Cline : Failed to create client: ${message}`); + const message = error instanceof Error ? error.message : "Unknown error" + console.error("Cline : Client creation failed:", message) + throw new Error(`Cline : Failed to create client: ${message}`) } } - return this.client; + return this.client } private cleanTerminalOutput(text: string): string { if (!text) { - return ''; + return "" } - return text - // Нормализуем переносы строк - .replace(/\r\n/g, '\n') - .replace(/\r/g, '\n') + return ( + text + // Нормализуем переносы строк + .replace(/\r\n/g, "\n") + .replace(/\r/g, "\n") - // Удаляем ANSI escape sequences - .replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, '') // Полный набор ANSI sequences - .replace(/\x9B[0-?]*[ -/]*[@-~]/g, '') // CSI sequences + // Удаляем ANSI escape sequences + .replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // Полный набор ANSI sequences + .replace(/\x9B[0-?]*[ -/]*[@-~]/g, "") // CSI sequences - // Удаляем последовательности установки заголовка терминала и прочие OSC sequences - .replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, '') + // Удаляем последовательности установки заголовка терминала и прочие OSC sequences + .replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, "") - // Удаляем управляющие символы - .replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, '') + // Удаляем управляющие символы + .replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, "") - // Удаляем escape-последовательности VS Code - .replace(/\x1B[PD].*?\x1B\\/g, '') // DCS sequences - .replace(/\x1B_.*?\x1B\\/g, '') // APC sequences - .replace(/\x1B\^.*?\x1B\\/g, '') // PM sequences - .replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, '') // Cursor movement and clear screen + // Удаляем escape-последовательности VS Code + .replace(/\x1B[PD].*?\x1B\\/g, "") // DCS sequences + .replace(/\x1B_.*?\x1B\\/g, "") // APC sequences + .replace(/\x1B\^.*?\x1B\\/g, "") // PM sequences + .replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, "") // Cursor movement and clear screen - // Удаляем пути Windows и служебную информацию - .replace(/^(?:PS )?[A-Z]:\\[^\n]*$/mg, '') - .replace(/^;?Cwd=.*$/mg, '') + // Удаляем пути Windows и служебную информацию + .replace(/^(?:PS )?[A-Z]:\\[^\n]*$/gm, "") + .replace(/^;?Cwd=.*$/gm, "") - // Очищаем экранированные последовательности - .replace(/\\x[0-9a-fA-F]{2}/g, '') - .replace(/\\u[0-9a-fA-F]{4}/g, '') + // Очищаем экранированные последовательности + .replace(/\\x[0-9a-fA-F]{2}/g, "") + .replace(/\\u[0-9a-fA-F]{4}/g, "") - // Финальная очистка - .replace(/\n{3,}/g, '\n\n') // Убираем множественные пустые строки - .trim(); + // Финальная очистка + .replace(/\n{3,}/g, "\n\n") // Убираем множественные пустые строки + .trim() + ) } private cleanMessageContent(content: any): any { if (!content) { - return content; + return content } - if (typeof content === 'string') { - return this.cleanTerminalOutput(content); + if (typeof content === "string") { + return this.cleanTerminalOutput(content) } if (Array.isArray(content)) { - return content.map(item => this.cleanMessageContent(item)); + return content.map((item) => this.cleanMessageContent(item)) } - if (typeof content === 'object') { - const cleaned: any = {}; + if (typeof content === "object") { + const cleaned: any = {} for (const [key, value] of Object.entries(content)) { - cleaned[key] = this.cleanMessageContent(value); + cleaned[key] = this.cleanMessageContent(value) } - return cleaned; + return cleaned } - return content; + return content } async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - // Ensure clean state before starting a new request - this.ensureCleanState(); - const client: vscode.LanguageModelChat = await this.getClient(); + this.ensureCleanState() + const client: vscode.LanguageModelChat = await this.getClient() // Clean system prompt and messages - const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt); - const cleanedMessages = messages.map(msg => ({ + const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt) + const cleanedMessages = messages.map((msg) => ({ ...msg, - content: this.cleanMessageContent(msg.content) - })); + content: this.cleanMessageContent(msg.content), + })) // Convert Anthropic messages to VS Code LM messages const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [ vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt), ...convertToVsCodeLmMessages(cleanedMessages), - ]; + ] // Initialize cancellation token for the request - this.currentRequestCancellation = new vscode.CancellationTokenSource(); + this.currentRequestCancellation = new vscode.CancellationTokenSource() // Calculate input tokens before starting the stream - const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages); + const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages) // Accumulate the text and count at the end of the stream to reduce token counting overhead. - let accumulatedText: string = ''; + let accumulatedText: string = "" try { - // Create the response stream with minimal required options const requestOptions: vscode.LanguageModelChatRequestOptions = { - justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.` - }; + justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`, + } // Note: Tool support is currently provided by the VSCode Language Model API directly // Extensions can register tools using vscode.lm.registerTool() @@ -368,40 +357,40 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { const response: vscode.LanguageModelChatResponse = await client.sendRequest( vsCodeLmMessages, requestOptions, - this.currentRequestCancellation.token - ); + this.currentRequestCancellation.token, + ) // Consume the stream and handle both text and tool call chunks for await (const chunk of response.stream) { if (chunk instanceof vscode.LanguageModelTextPart) { // Validate text part value - if (typeof chunk.value !== 'string') { - console.warn('Cline : Invalid text part value received:', chunk.value); - continue; + if (typeof chunk.value !== "string") { + console.warn("Cline : Invalid text part value received:", chunk.value) + continue } - accumulatedText += chunk.value; + accumulatedText += chunk.value yield { type: "text", text: chunk.value, - }; + } } else if (chunk instanceof vscode.LanguageModelToolCallPart) { try { // Validate tool call parameters - if (!chunk.name || typeof chunk.name !== 'string') { - console.warn('Cline : Invalid tool name received:', chunk.name); - continue; + if (!chunk.name || typeof chunk.name !== "string") { + console.warn("Cline : Invalid tool name received:", chunk.name) + continue } - if (!chunk.callId || typeof chunk.callId !== 'string') { - console.warn('Cline : Invalid tool callId received:', chunk.callId); - continue; + if (!chunk.callId || typeof chunk.callId !== "string") { + console.warn("Cline : Invalid tool callId received:", chunk.callId) + continue } // Ensure input is a valid object - if (!chunk.input || typeof chunk.input !== 'object') { - console.warn('Cline : Invalid tool input received:', chunk.input); - continue; + if (!chunk.input || typeof chunk.input !== "object") { + console.warn("Cline : Invalid tool input received:", chunk.input) + continue } // Convert tool calls to text format with proper error handling @@ -409,82 +398,75 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { type: "tool_call", name: chunk.name, arguments: chunk.input, - callId: chunk.callId - }; + callId: chunk.callId, + } - const toolCallText = JSON.stringify(toolCall); - accumulatedText += toolCallText; + const toolCallText = JSON.stringify(toolCall) + accumulatedText += toolCallText // Log tool call for debugging - console.debug('Cline : Processing tool call:', { + console.debug("Cline : Processing tool call:", { name: chunk.name, callId: chunk.callId, - inputSize: JSON.stringify(chunk.input).length - }); + inputSize: JSON.stringify(chunk.input).length, + }) yield { type: "text", text: toolCallText, - }; + } } catch (error) { - console.error('Cline : Failed to process tool call:', error); + console.error("Cline : Failed to process tool call:", error) // Continue processing other chunks even if one fails - continue; + continue } } else { - console.warn('Cline : Unknown chunk type received:', chunk); + console.warn("Cline : Unknown chunk type received:", chunk) } } // Count tokens in the accumulated text after stream completion - const totalOutputTokens: number = await this.countTokens(accumulatedText); + const totalOutputTokens: number = await this.countTokens(accumulatedText) // Report final usage after stream completion yield { type: "usage", inputTokens: totalInputTokens, outputTokens: totalOutputTokens, - totalCost: calculateApiCost( - this.getModel().info, - totalInputTokens, - totalOutputTokens - ) - }; - } - catch (error: unknown) { - - this.ensureCleanState(); + totalCost: calculateApiCost(this.getModel().info, totalInputTokens, totalOutputTokens), + } + } catch (error: unknown) { + this.ensureCleanState() if (error instanceof vscode.CancellationError) { - - throw new Error("Cline : Request cancelled by user"); + throw new Error("Cline : Request cancelled by user") } if (error instanceof Error) { - console.error('Cline : Stream error details:', { + console.error("Cline : Stream error details:", { message: error.message, stack: error.stack, - name: error.name - }); + name: error.name, + }) // Return original error if it's already an Error instance - throw error; - } else if (typeof error === 'object' && error !== null) { + throw error + } else if (typeof error === "object" && error !== null) { // Handle error-like objects - const errorDetails = JSON.stringify(error, null, 2); - console.error('Cline : Stream error object:', errorDetails); - throw new Error(`Cline : Response stream error: ${errorDetails}`); + const errorDetails = JSON.stringify(error, null, 2) + console.error("Cline : Stream error object:", errorDetails) + throw new Error(`Cline : Response stream error: ${errorDetails}`) } else { // Fallback for unknown error types - const errorMessage = String(error); - console.error('Cline : Unknown stream error:', errorMessage); - throw new Error(`Cline : Response stream error: ${errorMessage}`); + const errorMessage = String(error) + console.error("Cline : Unknown stream error:", errorMessage) + throw new Error(`Cline : Response stream error: ${errorMessage}`) } } } // Return model information based on the current client state - getModel(): { id: string; info: ModelInfo; } { + getModel(): { id: string; info: ModelInfo } { if (this.client) { // Validate client properties const requiredProps = { @@ -492,68 +474,69 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { vendor: this.client.vendor, family: this.client.family, version: this.client.version, - maxInputTokens: this.client.maxInputTokens - }; + maxInputTokens: this.client.maxInputTokens, + } // Log any missing properties for debugging for (const [prop, value] of Object.entries(requiredProps)) { if (!value && value !== 0) { - console.warn(`Cline : Client missing ${prop} property`); + console.warn(`Cline : Client missing ${prop} property`) } } // Construct model ID using available information - const modelParts = [ - this.client.vendor, - this.client.family, - this.client.version - ].filter(Boolean); + const modelParts = [this.client.vendor, this.client.family, this.client.version].filter(Boolean) - const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR); + const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR) // Build model info with conservative defaults for missing values const modelInfo: ModelInfo = { maxTokens: -1, // Unlimited tokens by default - contextWindow: typeof this.client.maxInputTokens === 'number' - ? Math.max(0, this.client.maxInputTokens) - : openAiModelInfoSaneDefaults.contextWindow, + contextWindow: + typeof this.client.maxInputTokens === "number" + ? Math.max(0, this.client.maxInputTokens) + : openAiModelInfoSaneDefaults.contextWindow, supportsImages: false, // VSCode Language Model API currently doesn't support image inputs supportsPromptCache: true, inputPrice: 0, outputPrice: 0, - description: `VSCode Language Model: ${modelId}` - }; + description: `VSCode Language Model: ${modelId}`, + } - return { id: modelId, info: modelInfo }; + return { id: modelId, info: modelInfo } } // Fallback when no client is available const fallbackId = this.options.vsCodeLmModelSelector ? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector) - : "vscode-lm"; + : "vscode-lm" - console.debug('Cline : No client available, using fallback model info'); + console.debug("Cline : No client available, using fallback model info") return { id: fallbackId, info: { ...openAiModelInfoSaneDefaults, - description: `VSCode Language Model (Fallback): ${fallbackId}` - } - }; + description: `VSCode Language Model (Fallback): ${fallbackId}`, + }, + } } async completePrompt(prompt: string): Promise { try { - const client = await this.getClient(); - const response = await client.sendRequest([vscode.LanguageModelChatMessage.User(prompt)], {}, new vscode.CancellationTokenSource().token); - let result = ""; + const client = await this.getClient() + const response = await client.sendRequest( + [vscode.LanguageModelChatMessage.User(prompt)], + {}, + new vscode.CancellationTokenSource().token, + ) + let result = "" for await (const chunk of response.stream) { if (chunk instanceof vscode.LanguageModelTextPart) { - result += chunk.value; + result += chunk.value } } - return result; + return result } catch (error) { if (error instanceof Error) { throw new Error(`VSCode LM completion error: ${error.message}`) diff --git a/src/api/transform/__tests__/bedrock-converse-format.test.ts b/src/api/transform/__tests__/bedrock-converse-format.test.ts index c9a0190..c46eb94 100644 --- a/src/api/transform/__tests__/bedrock-converse-format.test.ts +++ b/src/api/transform/__tests__/bedrock-converse-format.test.ts @@ -1,252 +1,250 @@ -import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format' -import { Anthropic } from '@anthropic-ai/sdk' -import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime' -import { StreamEvent } from '../../providers/bedrock' +import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../bedrock-converse-format" +import { Anthropic } from "@anthropic-ai/sdk" +import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime" +import { StreamEvent } from "../../providers/bedrock" -describe('bedrock-converse-format', () => { - describe('convertToBedrockConverseMessages', () => { - test('converts simple text messages correctly', () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' } - ] +describe("bedrock-converse-format", () => { + describe("convertToBedrockConverseMessages", () => { + test("converts simple text messages correctly", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: "Hi there" }, + ] - const result = convertToBedrockConverseMessages(messages) + const result = convertToBedrockConverseMessages(messages) - expect(result).toEqual([ - { - role: 'user', - content: [{ text: 'Hello' }] - }, - { - role: 'assistant', - content: [{ text: 'Hi there' }] - } - ]) - }) + expect(result).toEqual([ + { + role: "user", + content: [{ text: "Hello" }], + }, + { + role: "assistant", + content: [{ text: "Hi there" }], + }, + ]) + }) - test('converts messages with images correctly', () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [ - { - type: 'text', - text: 'Look at this image:' - }, - { - type: 'image', - source: { - type: 'base64', - data: 'SGVsbG8=', // "Hello" in base64 - media_type: 'image/jpeg' as const - } - } - ] - } - ] + test("converts messages with images correctly", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text", + text: "Look at this image:", + }, + { + type: "image", + source: { + type: "base64", + data: "SGVsbG8=", // "Hello" in base64 + media_type: "image/jpeg" as const, + }, + }, + ], + }, + ] - const result = convertToBedrockConverseMessages(messages) + const result = convertToBedrockConverseMessages(messages) - if (!result[0] || !result[0].content) { - fail('Expected result to have content') - return - } + if (!result[0] || !result[0].content) { + fail("Expected result to have content") + return + } - expect(result[0].role).toBe('user') - expect(result[0].content).toHaveLength(2) - expect(result[0].content[0]).toEqual({ text: 'Look at this image:' }) - - const imageBlock = result[0].content[1] as ContentBlock - if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) { - expect(imageBlock.image.format).toBe('jpeg') - expect(imageBlock.image.source).toBeDefined() - expect(imageBlock.image.source.bytes).toBeDefined() - } else { - fail('Expected image block not found') - } - }) + expect(result[0].role).toBe("user") + expect(result[0].content).toHaveLength(2) + expect(result[0].content[0]).toEqual({ text: "Look at this image:" }) - test('converts tool use messages correctly', () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'assistant', - content: [ - { - type: 'tool_use', - id: 'test-id', - name: 'read_file', - input: { - path: 'test.txt' - } - } - ] - } - ] + const imageBlock = result[0].content[1] as ContentBlock + if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) { + expect(imageBlock.image.format).toBe("jpeg") + expect(imageBlock.image.source).toBeDefined() + expect(imageBlock.image.source.bytes).toBeDefined() + } else { + fail("Expected image block not found") + } + }) - const result = convertToBedrockConverseMessages(messages) + test("converts tool use messages correctly", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "test-id", + name: "read_file", + input: { + path: "test.txt", + }, + }, + ], + }, + ] - if (!result[0] || !result[0].content) { - fail('Expected result to have content') - return - } + const result = convertToBedrockConverseMessages(messages) - expect(result[0].role).toBe('assistant') - const toolBlock = result[0].content[0] as ContentBlock - if ('toolUse' in toolBlock && toolBlock.toolUse) { - expect(toolBlock.toolUse).toEqual({ - toolUseId: 'test-id', - name: 'read_file', - input: '\n\ntest.txt\n\n' - }) - } else { - fail('Expected tool use block not found') - } - }) + if (!result[0] || !result[0].content) { + fail("Expected result to have content") + return + } - test('converts tool result messages correctly', () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'assistant', - content: [ - { - type: 'tool_result', - tool_use_id: 'test-id', - content: [{ type: 'text', text: 'File contents here' }] - } - ] - } - ] + expect(result[0].role).toBe("assistant") + const toolBlock = result[0].content[0] as ContentBlock + if ("toolUse" in toolBlock && toolBlock.toolUse) { + expect(toolBlock.toolUse).toEqual({ + toolUseId: "test-id", + name: "read_file", + input: "\n\ntest.txt\n\n", + }) + } else { + fail("Expected tool use block not found") + } + }) - const result = convertToBedrockConverseMessages(messages) + test("converts tool result messages correctly", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "tool_result", + tool_use_id: "test-id", + content: [{ type: "text", text: "File contents here" }], + }, + ], + }, + ] - if (!result[0] || !result[0].content) { - fail('Expected result to have content') - return - } + const result = convertToBedrockConverseMessages(messages) - expect(result[0].role).toBe('assistant') - const resultBlock = result[0].content[0] as ContentBlock - if ('toolResult' in resultBlock && resultBlock.toolResult) { - const expectedContent: ToolResultContentBlock[] = [ - { text: 'File contents here' } - ] - expect(resultBlock.toolResult).toEqual({ - toolUseId: 'test-id', - content: expectedContent, - status: 'success' - }) - } else { - fail('Expected tool result block not found') - } - }) + if (!result[0] || !result[0].content) { + fail("Expected result to have content") + return + } - test('handles text content correctly', () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [ - { - type: 'text', - text: 'Hello world' - } - ] - } - ] + expect(result[0].role).toBe("assistant") + const resultBlock = result[0].content[0] as ContentBlock + if ("toolResult" in resultBlock && resultBlock.toolResult) { + const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }] + expect(resultBlock.toolResult).toEqual({ + toolUseId: "test-id", + content: expectedContent, + status: "success", + }) + } else { + fail("Expected tool result block not found") + } + }) - const result = convertToBedrockConverseMessages(messages) + test("handles text content correctly", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text", + text: "Hello world", + }, + ], + }, + ] - if (!result[0] || !result[0].content) { - fail('Expected result to have content') - return - } + const result = convertToBedrockConverseMessages(messages) - expect(result[0].role).toBe('user') - expect(result[0].content).toHaveLength(1) - const textBlock = result[0].content[0] as ContentBlock - expect(textBlock).toEqual({ text: 'Hello world' }) - }) - }) + if (!result[0] || !result[0].content) { + fail("Expected result to have content") + return + } - describe('convertToAnthropicMessage', () => { - test('converts metadata events correctly', () => { - const event: StreamEvent = { - metadata: { - usage: { - inputTokens: 10, - outputTokens: 20 - } - } - } + expect(result[0].role).toBe("user") + expect(result[0].content).toHaveLength(1) + const textBlock = result[0].content[0] as ContentBlock + expect(textBlock).toEqual({ text: "Hello world" }) + }) + }) - const result = convertToAnthropicMessage(event, 'test-model') + describe("convertToAnthropicMessage", () => { + test("converts metadata events correctly", () => { + const event: StreamEvent = { + metadata: { + usage: { + inputTokens: 10, + outputTokens: 20, + }, + }, + } - expect(result).toEqual({ - id: '', - type: 'message', - role: 'assistant', - model: 'test-model', - usage: { - input_tokens: 10, - output_tokens: 20 - } - }) - }) + const result = convertToAnthropicMessage(event, "test-model") - test('converts content block start events correctly', () => { - const event: StreamEvent = { - contentBlockStart: { - start: { - text: 'Hello' - } - } - } + expect(result).toEqual({ + id: "", + type: "message", + role: "assistant", + model: "test-model", + usage: { + input_tokens: 10, + output_tokens: 20, + }, + }) + }) - const result = convertToAnthropicMessage(event, 'test-model') + test("converts content block start events correctly", () => { + const event: StreamEvent = { + contentBlockStart: { + start: { + text: "Hello", + }, + }, + } - expect(result).toEqual({ - type: 'message', - role: 'assistant', - content: [{ type: 'text', text: 'Hello' }], - model: 'test-model' - }) - }) + const result = convertToAnthropicMessage(event, "test-model") - test('converts content block delta events correctly', () => { - const event: StreamEvent = { - contentBlockDelta: { - delta: { - text: ' world' - } - } - } + expect(result).toEqual({ + type: "message", + role: "assistant", + content: [{ type: "text", text: "Hello" }], + model: "test-model", + }) + }) - const result = convertToAnthropicMessage(event, 'test-model') + test("converts content block delta events correctly", () => { + const event: StreamEvent = { + contentBlockDelta: { + delta: { + text: " world", + }, + }, + } - expect(result).toEqual({ - type: 'message', - role: 'assistant', - content: [{ type: 'text', text: ' world' }], - model: 'test-model' - }) - }) + const result = convertToAnthropicMessage(event, "test-model") - test('converts message stop events correctly', () => { - const event: StreamEvent = { - messageStop: { - stopReason: 'end_turn' as const - } - } + expect(result).toEqual({ + type: "message", + role: "assistant", + content: [{ type: "text", text: " world" }], + model: "test-model", + }) + }) - const result = convertToAnthropicMessage(event, 'test-model') + test("converts message stop events correctly", () => { + const event: StreamEvent = { + messageStop: { + stopReason: "end_turn" as const, + }, + } - expect(result).toEqual({ - type: 'message', - role: 'assistant', - stop_reason: 'end_turn', - stop_sequence: null, - model: 'test-model' - }) - }) - }) + const result = convertToAnthropicMessage(event, "test-model") + + expect(result).toEqual({ + type: "message", + role: "assistant", + stop_reason: "end_turn", + stop_sequence: null, + model: "test-model", + }) + }) + }) }) diff --git a/src/api/transform/__tests__/openai-format.test.ts b/src/api/transform/__tests__/openai-format.test.ts index 32673dc..f37d369 100644 --- a/src/api/transform/__tests__/openai-format.test.ts +++ b/src/api/transform/__tests__/openai-format.test.ts @@ -1,257 +1,275 @@ -import { convertToOpenAiMessages, convertToAnthropicMessage } from '../openai-format'; -import { Anthropic } from '@anthropic-ai/sdk'; -import OpenAI from 'openai'; +import { convertToOpenAiMessages, convertToAnthropicMessage } from "../openai-format" +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" -type PartialChatCompletion = Omit & { - choices: Array & { - message: OpenAI.Chat.Completions.ChatCompletion.Choice['message']; - finish_reason: string; - index: number; - }>; -}; +type PartialChatCompletion = Omit & { + choices: Array< + Partial & { + message: OpenAI.Chat.Completions.ChatCompletion.Choice["message"] + finish_reason: string + index: number + } + > +} -describe('OpenAI Format Transformations', () => { - describe('convertToOpenAiMessages', () => { - it('should convert simple text messages', () => { - const anthropicMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: 'Hello' - }, - { - role: 'assistant', - content: 'Hi there!' - } - ]; +describe("OpenAI Format Transformations", () => { + describe("convertToOpenAiMessages", () => { + it("should convert simple text messages", () => { + const anthropicMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: "Hello", + }, + { + role: "assistant", + content: "Hi there!", + }, + ] - const openAiMessages = convertToOpenAiMessages(anthropicMessages); - expect(openAiMessages).toHaveLength(2); - expect(openAiMessages[0]).toEqual({ - role: 'user', - content: 'Hello' - }); - expect(openAiMessages[1]).toEqual({ - role: 'assistant', - content: 'Hi there!' - }); - }); + const openAiMessages = convertToOpenAiMessages(anthropicMessages) + expect(openAiMessages).toHaveLength(2) + expect(openAiMessages[0]).toEqual({ + role: "user", + content: "Hello", + }) + expect(openAiMessages[1]).toEqual({ + role: "assistant", + content: "Hi there!", + }) + }) - it('should handle messages with image content', () => { - const anthropicMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [ - { - type: 'text', - text: 'What is in this image?' - }, - { - type: 'image', - source: { - type: 'base64', - media_type: 'image/jpeg', - data: 'base64data' - } - } - ] - } - ]; + it("should handle messages with image content", () => { + const anthropicMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: "base64data", + }, + }, + ], + }, + ] - const openAiMessages = convertToOpenAiMessages(anthropicMessages); - expect(openAiMessages).toHaveLength(1); - expect(openAiMessages[0].role).toBe('user'); - - const content = openAiMessages[0].content as Array<{ - type: string; - text?: string; - image_url?: { url: string }; - }>; - - expect(Array.isArray(content)).toBe(true); - expect(content).toHaveLength(2); - expect(content[0]).toEqual({ type: 'text', text: 'What is in this image?' }); - expect(content[1]).toEqual({ - type: 'image_url', - image_url: { url: '' } - }); - }); + const openAiMessages = convertToOpenAiMessages(anthropicMessages) + expect(openAiMessages).toHaveLength(1) + expect(openAiMessages[0].role).toBe("user") - it('should handle assistant messages with tool use', () => { - const anthropicMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'assistant', - content: [ - { - type: 'text', - text: 'Let me check the weather.' - }, - { - type: 'tool_use', - id: 'weather-123', - name: 'get_weather', - input: { city: 'London' } - } - ] - } - ]; + const content = openAiMessages[0].content as Array<{ + type: string + text?: string + image_url?: { url: string } + }> - const openAiMessages = convertToOpenAiMessages(anthropicMessages); - expect(openAiMessages).toHaveLength(1); - - const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam; - expect(assistantMessage.role).toBe('assistant'); - expect(assistantMessage.content).toBe('Let me check the weather.'); - expect(assistantMessage.tool_calls).toHaveLength(1); - expect(assistantMessage.tool_calls![0]).toEqual({ - id: 'weather-123', - type: 'function', - function: { - name: 'get_weather', - arguments: JSON.stringify({ city: 'London' }) - } - }); - }); + expect(Array.isArray(content)).toBe(true) + expect(content).toHaveLength(2) + expect(content[0]).toEqual({ type: "text", text: "What is in this image?" }) + expect(content[1]).toEqual({ + type: "image_url", + image_url: { url: "" }, + }) + }) - it('should handle user messages with tool results', () => { - const anthropicMessages: Anthropic.Messages.MessageParam[] = [ - { - role: 'user', - content: [ - { - type: 'tool_result', - tool_use_id: 'weather-123', - content: 'Current temperature in London: 20°C' - } - ] - } - ]; + it("should handle assistant messages with tool use", () => { + const anthropicMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "text", + text: "Let me check the weather.", + }, + { + type: "tool_use", + id: "weather-123", + name: "get_weather", + input: { city: "London" }, + }, + ], + }, + ] - const openAiMessages = convertToOpenAiMessages(anthropicMessages); - expect(openAiMessages).toHaveLength(1); - - const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam; - expect(toolMessage.role).toBe('tool'); - expect(toolMessage.tool_call_id).toBe('weather-123'); - expect(toolMessage.content).toBe('Current temperature in London: 20°C'); - }); - }); + const openAiMessages = convertToOpenAiMessages(anthropicMessages) + expect(openAiMessages).toHaveLength(1) - describe('convertToAnthropicMessage', () => { - it('should convert simple completion', () => { - const openAiCompletion: PartialChatCompletion = { - id: 'completion-123', - model: 'gpt-4', - choices: [{ - message: { - role: 'assistant', - content: 'Hello there!', - refusal: null - }, - finish_reason: 'stop', - index: 0 - }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15 - }, - created: 123456789, - object: 'chat.completion' - }; + const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam + expect(assistantMessage.role).toBe("assistant") + expect(assistantMessage.content).toBe("Let me check the weather.") + expect(assistantMessage.tool_calls).toHaveLength(1) + expect(assistantMessage.tool_calls![0]).toEqual({ + id: "weather-123", + type: "function", + function: { + name: "get_weather", + arguments: JSON.stringify({ city: "London" }), + }, + }) + }) - const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion); - expect(anthropicMessage.id).toBe('completion-123'); - expect(anthropicMessage.role).toBe('assistant'); - expect(anthropicMessage.content).toHaveLength(1); - expect(anthropicMessage.content[0]).toEqual({ - type: 'text', - text: 'Hello there!' - }); - expect(anthropicMessage.stop_reason).toBe('end_turn'); - expect(anthropicMessage.usage).toEqual({ - input_tokens: 10, - output_tokens: 5 - }); - }); + it("should handle user messages with tool results", () => { + const anthropicMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "weather-123", + content: "Current temperature in London: 20°C", + }, + ], + }, + ] - it('should handle tool calls in completion', () => { - const openAiCompletion: PartialChatCompletion = { - id: 'completion-123', - model: 'gpt-4', - choices: [{ - message: { - role: 'assistant', - content: 'Let me check the weather.', - tool_calls: [{ - id: 'weather-123', - type: 'function', - function: { - name: 'get_weather', - arguments: '{"city":"London"}' - } - }], - refusal: null - }, - finish_reason: 'tool_calls', - index: 0 - }], - usage: { - prompt_tokens: 15, - completion_tokens: 8, - total_tokens: 23 - }, - created: 123456789, - object: 'chat.completion' - }; + const openAiMessages = convertToOpenAiMessages(anthropicMessages) + expect(openAiMessages).toHaveLength(1) - const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion); - expect(anthropicMessage.content).toHaveLength(2); - expect(anthropicMessage.content[0]).toEqual({ - type: 'text', - text: 'Let me check the weather.' - }); - expect(anthropicMessage.content[1]).toEqual({ - type: 'tool_use', - id: 'weather-123', - name: 'get_weather', - input: { city: 'London' } - }); - expect(anthropicMessage.stop_reason).toBe('tool_use'); - }); + const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam + expect(toolMessage.role).toBe("tool") + expect(toolMessage.tool_call_id).toBe("weather-123") + expect(toolMessage.content).toBe("Current temperature in London: 20°C") + }) + }) - it('should handle invalid tool call arguments', () => { - const openAiCompletion: PartialChatCompletion = { - id: 'completion-123', - model: 'gpt-4', - choices: [{ - message: { - role: 'assistant', - content: 'Testing invalid arguments', - tool_calls: [{ - id: 'test-123', - type: 'function', - function: { - name: 'test_function', - arguments: 'invalid json' - } - }], - refusal: null - }, - finish_reason: 'tool_calls', - index: 0 - }], - created: 123456789, - object: 'chat.completion' - }; + describe("convertToAnthropicMessage", () => { + it("should convert simple completion", () => { + const openAiCompletion: PartialChatCompletion = { + id: "completion-123", + model: "gpt-4", + choices: [ + { + message: { + role: "assistant", + content: "Hello there!", + refusal: null, + }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + created: 123456789, + object: "chat.completion", + } - const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion); - expect(anthropicMessage.content).toHaveLength(2); - expect(anthropicMessage.content[1]).toEqual({ - type: 'tool_use', - id: 'test-123', - name: 'test_function', - input: {} // Should default to empty object for invalid JSON - }); - }); - }); -}); \ No newline at end of file + const anthropicMessage = convertToAnthropicMessage( + openAiCompletion as OpenAI.Chat.Completions.ChatCompletion, + ) + expect(anthropicMessage.id).toBe("completion-123") + expect(anthropicMessage.role).toBe("assistant") + expect(anthropicMessage.content).toHaveLength(1) + expect(anthropicMessage.content[0]).toEqual({ + type: "text", + text: "Hello there!", + }) + expect(anthropicMessage.stop_reason).toBe("end_turn") + expect(anthropicMessage.usage).toEqual({ + input_tokens: 10, + output_tokens: 5, + }) + }) + + it("should handle tool calls in completion", () => { + const openAiCompletion: PartialChatCompletion = { + id: "completion-123", + model: "gpt-4", + choices: [ + { + message: { + role: "assistant", + content: "Let me check the weather.", + tool_calls: [ + { + id: "weather-123", + type: "function", + function: { + name: "get_weather", + arguments: '{"city":"London"}', + }, + }, + ], + refusal: null, + }, + finish_reason: "tool_calls", + index: 0, + }, + ], + usage: { + prompt_tokens: 15, + completion_tokens: 8, + total_tokens: 23, + }, + created: 123456789, + object: "chat.completion", + } + + const anthropicMessage = convertToAnthropicMessage( + openAiCompletion as OpenAI.Chat.Completions.ChatCompletion, + ) + expect(anthropicMessage.content).toHaveLength(2) + expect(anthropicMessage.content[0]).toEqual({ + type: "text", + text: "Let me check the weather.", + }) + expect(anthropicMessage.content[1]).toEqual({ + type: "tool_use", + id: "weather-123", + name: "get_weather", + input: { city: "London" }, + }) + expect(anthropicMessage.stop_reason).toBe("tool_use") + }) + + it("should handle invalid tool call arguments", () => { + const openAiCompletion: PartialChatCompletion = { + id: "completion-123", + model: "gpt-4", + choices: [ + { + message: { + role: "assistant", + content: "Testing invalid arguments", + tool_calls: [ + { + id: "test-123", + type: "function", + function: { + name: "test_function", + arguments: "invalid json", + }, + }, + ], + refusal: null, + }, + finish_reason: "tool_calls", + index: 0, + }, + ], + created: 123456789, + object: "chat.completion", + } + + const anthropicMessage = convertToAnthropicMessage( + openAiCompletion as OpenAI.Chat.Completions.ChatCompletion, + ) + expect(anthropicMessage.content).toHaveLength(2) + expect(anthropicMessage.content[1]).toEqual({ + type: "tool_use", + id: "test-123", + name: "test_function", + input: {}, // Should default to empty object for invalid JSON + }) + }) + }) +}) diff --git a/src/api/transform/__tests__/stream.test.ts b/src/api/transform/__tests__/stream.test.ts index 32efd50..7cf2e9c 100644 --- a/src/api/transform/__tests__/stream.test.ts +++ b/src/api/transform/__tests__/stream.test.ts @@ -1,114 +1,114 @@ -import { ApiStreamChunk } from '../stream'; +import { ApiStreamChunk } from "../stream" -describe('API Stream Types', () => { - describe('ApiStreamChunk', () => { - it('should correctly handle text chunks', () => { - const textChunk: ApiStreamChunk = { - type: 'text', - text: 'Hello world' - }; +describe("API Stream Types", () => { + describe("ApiStreamChunk", () => { + it("should correctly handle text chunks", () => { + const textChunk: ApiStreamChunk = { + type: "text", + text: "Hello world", + } - expect(textChunk.type).toBe('text'); - expect(textChunk.text).toBe('Hello world'); - }); + expect(textChunk.type).toBe("text") + expect(textChunk.text).toBe("Hello world") + }) - it('should correctly handle usage chunks with cache information', () => { - const usageChunk: ApiStreamChunk = { - type: 'usage', - inputTokens: 100, - outputTokens: 50, - cacheWriteTokens: 20, - cacheReadTokens: 10 - }; + it("should correctly handle usage chunks with cache information", () => { + const usageChunk: ApiStreamChunk = { + type: "usage", + inputTokens: 100, + outputTokens: 50, + cacheWriteTokens: 20, + cacheReadTokens: 10, + } - expect(usageChunk.type).toBe('usage'); - expect(usageChunk.inputTokens).toBe(100); - expect(usageChunk.outputTokens).toBe(50); - expect(usageChunk.cacheWriteTokens).toBe(20); - expect(usageChunk.cacheReadTokens).toBe(10); - }); + expect(usageChunk.type).toBe("usage") + expect(usageChunk.inputTokens).toBe(100) + expect(usageChunk.outputTokens).toBe(50) + expect(usageChunk.cacheWriteTokens).toBe(20) + expect(usageChunk.cacheReadTokens).toBe(10) + }) - it('should handle usage chunks without cache tokens', () => { - const usageChunk: ApiStreamChunk = { - type: 'usage', - inputTokens: 100, - outputTokens: 50 - }; + it("should handle usage chunks without cache tokens", () => { + const usageChunk: ApiStreamChunk = { + type: "usage", + inputTokens: 100, + outputTokens: 50, + } - expect(usageChunk.type).toBe('usage'); - expect(usageChunk.inputTokens).toBe(100); - expect(usageChunk.outputTokens).toBe(50); - expect(usageChunk.cacheWriteTokens).toBeUndefined(); - expect(usageChunk.cacheReadTokens).toBeUndefined(); - }); + expect(usageChunk.type).toBe("usage") + expect(usageChunk.inputTokens).toBe(100) + expect(usageChunk.outputTokens).toBe(50) + expect(usageChunk.cacheWriteTokens).toBeUndefined() + expect(usageChunk.cacheReadTokens).toBeUndefined() + }) - it('should handle text chunks with empty strings', () => { - const emptyTextChunk: ApiStreamChunk = { - type: 'text', - text: '' - }; + it("should handle text chunks with empty strings", () => { + const emptyTextChunk: ApiStreamChunk = { + type: "text", + text: "", + } - expect(emptyTextChunk.type).toBe('text'); - expect(emptyTextChunk.text).toBe(''); - }); + expect(emptyTextChunk.type).toBe("text") + expect(emptyTextChunk.text).toBe("") + }) - it('should handle usage chunks with zero tokens', () => { - const zeroUsageChunk: ApiStreamChunk = { - type: 'usage', - inputTokens: 0, - outputTokens: 0 - }; + it("should handle usage chunks with zero tokens", () => { + const zeroUsageChunk: ApiStreamChunk = { + type: "usage", + inputTokens: 0, + outputTokens: 0, + } - expect(zeroUsageChunk.type).toBe('usage'); - expect(zeroUsageChunk.inputTokens).toBe(0); - expect(zeroUsageChunk.outputTokens).toBe(0); - }); + expect(zeroUsageChunk.type).toBe("usage") + expect(zeroUsageChunk.inputTokens).toBe(0) + expect(zeroUsageChunk.outputTokens).toBe(0) + }) - it('should handle usage chunks with large token counts', () => { - const largeUsageChunk: ApiStreamChunk = { - type: 'usage', - inputTokens: 1000000, - outputTokens: 500000, - cacheWriteTokens: 200000, - cacheReadTokens: 100000 - }; + it("should handle usage chunks with large token counts", () => { + const largeUsageChunk: ApiStreamChunk = { + type: "usage", + inputTokens: 1000000, + outputTokens: 500000, + cacheWriteTokens: 200000, + cacheReadTokens: 100000, + } - expect(largeUsageChunk.type).toBe('usage'); - expect(largeUsageChunk.inputTokens).toBe(1000000); - expect(largeUsageChunk.outputTokens).toBe(500000); - expect(largeUsageChunk.cacheWriteTokens).toBe(200000); - expect(largeUsageChunk.cacheReadTokens).toBe(100000); - }); + expect(largeUsageChunk.type).toBe("usage") + expect(largeUsageChunk.inputTokens).toBe(1000000) + expect(largeUsageChunk.outputTokens).toBe(500000) + expect(largeUsageChunk.cacheWriteTokens).toBe(200000) + expect(largeUsageChunk.cacheReadTokens).toBe(100000) + }) - it('should handle text chunks with special characters', () => { - const specialCharsChunk: ApiStreamChunk = { - type: 'text', - text: '!@#$%^&*()_+-=[]{}|;:,.<>?`~' - }; + it("should handle text chunks with special characters", () => { + const specialCharsChunk: ApiStreamChunk = { + type: "text", + text: "!@#$%^&*()_+-=[]{}|;:,.<>?`~", + } - expect(specialCharsChunk.type).toBe('text'); - expect(specialCharsChunk.text).toBe('!@#$%^&*()_+-=[]{}|;:,.<>?`~'); - }); + expect(specialCharsChunk.type).toBe("text") + expect(specialCharsChunk.text).toBe("!@#$%^&*()_+-=[]{}|;:,.<>?`~") + }) - it('should handle text chunks with unicode characters', () => { - const unicodeChunk: ApiStreamChunk = { - type: 'text', - text: '你好世界👋🌍' - }; + it("should handle text chunks with unicode characters", () => { + const unicodeChunk: ApiStreamChunk = { + type: "text", + text: "你好世界👋🌍", + } - expect(unicodeChunk.type).toBe('text'); - expect(unicodeChunk.text).toBe('你好世界👋🌍'); - }); + expect(unicodeChunk.type).toBe("text") + expect(unicodeChunk.text).toBe("你好世界👋🌍") + }) - it('should handle text chunks with multiline content', () => { - const multilineChunk: ApiStreamChunk = { - type: 'text', - text: 'Line 1\nLine 2\nLine 3' - }; + it("should handle text chunks with multiline content", () => { + const multilineChunk: ApiStreamChunk = { + type: "text", + text: "Line 1\nLine 2\nLine 3", + } - expect(multilineChunk.type).toBe('text'); - expect(multilineChunk.text).toBe('Line 1\nLine 2\nLine 3'); - expect(multilineChunk.text.split('\n')).toHaveLength(3); - }); - }); -}); \ No newline at end of file + expect(multilineChunk.type).toBe("text") + expect(multilineChunk.text).toBe("Line 1\nLine 2\nLine 3") + expect(multilineChunk.text.split("\n")).toHaveLength(3) + }) + }) +}) diff --git a/src/api/transform/__tests__/vscode-lm-format.test.ts b/src/api/transform/__tests__/vscode-lm-format.test.ts index eb71578..bc70da8 100644 --- a/src/api/transform/__tests__/vscode-lm-format.test.ts +++ b/src/api/transform/__tests__/vscode-lm-format.test.ts @@ -1,66 +1,66 @@ -import { Anthropic } from "@anthropic-ai/sdk"; -import * as vscode from 'vscode'; -import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from '../vscode-lm-format'; +import { Anthropic } from "@anthropic-ai/sdk" +import * as vscode from "vscode" +import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from "../vscode-lm-format" // Mock crypto const mockCrypto = { - randomUUID: () => 'test-uuid' -}; -global.crypto = mockCrypto as any; + randomUUID: () => "test-uuid", +} +global.crypto = mockCrypto as any // Define types for our mocked classes interface MockLanguageModelTextPart { - type: 'text'; - value: string; + type: "text" + value: string } interface MockLanguageModelToolCallPart { - type: 'tool_call'; - callId: string; - name: string; - input: any; + type: "tool_call" + callId: string + name: string + input: any } interface MockLanguageModelToolResultPart { - type: 'tool_result'; - toolUseId: string; - parts: MockLanguageModelTextPart[]; + type: "tool_result" + toolUseId: string + parts: MockLanguageModelTextPart[] } -type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart; +type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart interface MockLanguageModelChatMessage { - role: string; - name?: string; - content: MockMessageContent[]; + role: string + name?: string + content: MockMessageContent[] } // Mock vscode namespace -jest.mock('vscode', () => { +jest.mock("vscode", () => { const LanguageModelChatMessageRole = { - Assistant: 'assistant', - User: 'user' - }; + Assistant: "assistant", + User: "user", + } class MockLanguageModelTextPart { - type = 'text'; + type = "text" constructor(public value: string) {} } class MockLanguageModelToolCallPart { - type = 'tool_call'; + type = "tool_call" constructor( public callId: string, public name: string, - public input: any + public input: any, ) {} } class MockLanguageModelToolResultPart { - type = 'tool_result'; + type = "tool_result" constructor( public toolUseId: string, - public parts: MockLanguageModelTextPart[] + public parts: MockLanguageModelTextPart[], ) {} } @@ -68,179 +68,189 @@ jest.mock('vscode', () => { LanguageModelChatMessage: { Assistant: jest.fn((content) => ({ role: LanguageModelChatMessageRole.Assistant, - name: 'assistant', - content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] + name: "assistant", + content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)], })), User: jest.fn((content) => ({ role: LanguageModelChatMessageRole.User, - name: 'user', - content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] - })) + name: "user", + content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)], + })), }, LanguageModelChatMessageRole, LanguageModelTextPart: MockLanguageModelTextPart, LanguageModelToolCallPart: MockLanguageModelToolCallPart, - LanguageModelToolResultPart: MockLanguageModelToolResultPart - }; -}); + LanguageModelToolResultPart: MockLanguageModelToolResultPart, + } +}) -describe('vscode-lm-format', () => { - describe('convertToVsCodeLmMessages', () => { - it('should convert simple string messages', () => { +describe("vscode-lm-format", () => { + describe("convertToVsCodeLmMessages", () => { + it("should convert simple string messages", () => { const messages: Anthropic.Messages.MessageParam[] = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' } - ]; + { role: "user", content: "Hello" }, + { role: "assistant", content: "Hi there" }, + ] - const result = convertToVsCodeLmMessages(messages); - - expect(result).toHaveLength(2); - expect(result[0].role).toBe('user'); - expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe('Hello'); - expect(result[1].role).toBe('assistant'); - expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe('Hi there'); - }); + const result = convertToVsCodeLmMessages(messages) - it('should handle complex user messages with tool results', () => { - const messages: Anthropic.Messages.MessageParam[] = [{ - role: 'user', - content: [ - { type: 'text', text: 'Here is the result:' }, - { - type: 'tool_result', - tool_use_id: 'tool-1', - content: 'Tool output' - } - ] - }]; + expect(result).toHaveLength(2) + expect(result[0].role).toBe("user") + expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe("Hello") + expect(result[1].role).toBe("assistant") + expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe("Hi there") + }) - const result = convertToVsCodeLmMessages(messages); - - expect(result).toHaveLength(1); - expect(result[0].role).toBe('user'); - expect(result[0].content).toHaveLength(2); - const [toolResult, textContent] = result[0].content as [MockLanguageModelToolResultPart, MockLanguageModelTextPart]; - expect(toolResult.type).toBe('tool_result'); - expect(textContent.type).toBe('text'); - }); + it("should handle complex user messages with tool results", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { type: "text", text: "Here is the result:" }, + { + type: "tool_result", + tool_use_id: "tool-1", + content: "Tool output", + }, + ], + }, + ] - it('should handle complex assistant messages with tool calls', () => { - const messages: Anthropic.Messages.MessageParam[] = [{ - role: 'assistant', - content: [ - { type: 'text', text: 'Let me help you with that.' }, - { - type: 'tool_use', - id: 'tool-1', - name: 'calculator', - input: { operation: 'add', numbers: [2, 2] } - } - ] - }]; + const result = convertToVsCodeLmMessages(messages) - const result = convertToVsCodeLmMessages(messages); - - expect(result).toHaveLength(1); - expect(result[0].role).toBe('assistant'); - expect(result[0].content).toHaveLength(2); - const [toolCall, textContent] = result[0].content as [MockLanguageModelToolCallPart, MockLanguageModelTextPart]; - expect(toolCall.type).toBe('tool_call'); - expect(textContent.type).toBe('text'); - }); + expect(result).toHaveLength(1) + expect(result[0].role).toBe("user") + expect(result[0].content).toHaveLength(2) + const [toolResult, textContent] = result[0].content as [ + MockLanguageModelToolResultPart, + MockLanguageModelTextPart, + ] + expect(toolResult.type).toBe("tool_result") + expect(textContent.type).toBe("text") + }) - it('should handle image blocks with appropriate placeholders', () => { - const messages: Anthropic.Messages.MessageParam[] = [{ - role: 'user', - content: [ - { type: 'text', text: 'Look at this:' }, - { - type: 'image', - source: { - type: 'base64', - media_type: 'image/png', - data: 'base64data' - } - } - ] - }]; + it("should handle complex assistant messages with tool calls", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { type: "text", text: "Let me help you with that." }, + { + type: "tool_use", + id: "tool-1", + name: "calculator", + input: { operation: "add", numbers: [2, 2] }, + }, + ], + }, + ] - const result = convertToVsCodeLmMessages(messages); - - expect(result).toHaveLength(1); - const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart; - expect(imagePlaceholder.value).toContain('[Image (base64): image/png not supported by VSCode LM API]'); - }); - }); + const result = convertToVsCodeLmMessages(messages) - describe('convertToAnthropicRole', () => { - it('should convert assistant role correctly', () => { - const result = convertToAnthropicRole('assistant' as any); - expect(result).toBe('assistant'); - }); + expect(result).toHaveLength(1) + expect(result[0].role).toBe("assistant") + expect(result[0].content).toHaveLength(2) + const [toolCall, textContent] = result[0].content as [ + MockLanguageModelToolCallPart, + MockLanguageModelTextPart, + ] + expect(toolCall.type).toBe("tool_call") + expect(textContent.type).toBe("text") + }) - it('should convert user role correctly', () => { - const result = convertToAnthropicRole('user' as any); - expect(result).toBe('user'); - }); + it("should handle image blocks with appropriate placeholders", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { type: "text", text: "Look at this:" }, + { + type: "image", + source: { + type: "base64", + media_type: "image/png", + data: "base64data", + }, + }, + ], + }, + ] - it('should return null for unknown roles', () => { - const result = convertToAnthropicRole('unknown' as any); - expect(result).toBeNull(); - }); - }); + const result = convertToVsCodeLmMessages(messages) - describe('convertToAnthropicMessage', () => { - it('should convert assistant message with text content', async () => { + expect(result).toHaveLength(1) + const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart + expect(imagePlaceholder.value).toContain("[Image (base64): image/png not supported by VSCode LM API]") + }) + }) + + describe("convertToAnthropicRole", () => { + it("should convert assistant role correctly", () => { + const result = convertToAnthropicRole("assistant" as any) + expect(result).toBe("assistant") + }) + + it("should convert user role correctly", () => { + const result = convertToAnthropicRole("user" as any) + expect(result).toBe("user") + }) + + it("should return null for unknown roles", () => { + const result = convertToAnthropicRole("unknown" as any) + expect(result).toBeNull() + }) + }) + + describe("convertToAnthropicMessage", () => { + it("should convert assistant message with text content", async () => { const vsCodeMessage = { - role: 'assistant', - name: 'assistant', - content: [new vscode.LanguageModelTextPart('Hello')] - }; + role: "assistant", + name: "assistant", + content: [new vscode.LanguageModelTextPart("Hello")], + } - const result = await convertToAnthropicMessage(vsCodeMessage as any); - - expect(result.role).toBe('assistant'); - expect(result.content).toHaveLength(1); + const result = await convertToAnthropicMessage(vsCodeMessage as any) + + expect(result.role).toBe("assistant") + expect(result.content).toHaveLength(1) expect(result.content[0]).toEqual({ - type: 'text', - text: 'Hello' - }); - expect(result.id).toBe('test-uuid'); - }); + type: "text", + text: "Hello", + }) + expect(result.id).toBe("test-uuid") + }) - it('should convert assistant message with tool calls', async () => { + it("should convert assistant message with tool calls", async () => { const vsCodeMessage = { - role: 'assistant', - name: 'assistant', - content: [new vscode.LanguageModelToolCallPart( - 'call-1', - 'calculator', - { operation: 'add', numbers: [2, 2] } - )] - }; + role: "assistant", + name: "assistant", + content: [ + new vscode.LanguageModelToolCallPart("call-1", "calculator", { operation: "add", numbers: [2, 2] }), + ], + } - const result = await convertToAnthropicMessage(vsCodeMessage as any); - - expect(result.content).toHaveLength(1); + const result = await convertToAnthropicMessage(vsCodeMessage as any) + + expect(result.content).toHaveLength(1) expect(result.content[0]).toEqual({ - type: 'tool_use', - id: 'call-1', - name: 'calculator', - input: { operation: 'add', numbers: [2, 2] } - }); - expect(result.id).toBe('test-uuid'); - }); + type: "tool_use", + id: "call-1", + name: "calculator", + input: { operation: "add", numbers: [2, 2] }, + }) + expect(result.id).toBe("test-uuid") + }) - it('should throw error for non-assistant messages', async () => { + it("should throw error for non-assistant messages", async () => { const vsCodeMessage = { - role: 'user', - name: 'user', - content: [new vscode.LanguageModelTextPart('Hello')] - }; + role: "user", + name: "user", + content: [new vscode.LanguageModelTextPart("Hello")], + } - await expect(convertToAnthropicMessage(vsCodeMessage as any)) - .rejects - .toThrow('Cline : Only assistant messages are supported.'); - }); - }); -}); \ No newline at end of file + await expect(convertToAnthropicMessage(vsCodeMessage as any)).rejects.toThrow( + "Cline : Only assistant messages are supported.", + ) + }) + }) +}) diff --git a/src/api/transform/bedrock-converse-format.ts b/src/api/transform/bedrock-converse-format.ts index d3b9abd..07529db 100644 --- a/src/api/transform/bedrock-converse-format.ts +++ b/src/api/transform/bedrock-converse-format.ts @@ -8,210 +8,216 @@ import { StreamEvent } from "../providers/bedrock" /** * Convert Anthropic messages to Bedrock Converse format */ -export function convertToBedrockConverseMessages( - anthropicMessages: Anthropic.Messages.MessageParam[] -): Message[] { - return anthropicMessages.map(anthropicMessage => { - // Map Anthropic roles to Bedrock roles - const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user" +export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] { + return anthropicMessages.map((anthropicMessage) => { + // Map Anthropic roles to Bedrock roles + const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user" - if (typeof anthropicMessage.content === "string") { - return { - role, - content: [{ - text: anthropicMessage.content - }] as ContentBlock[] - } - } + if (typeof anthropicMessage.content === "string") { + return { + role, + content: [ + { + text: anthropicMessage.content, + }, + ] as ContentBlock[], + } + } - // Process complex content types - const content = anthropicMessage.content.map(block => { - const messageBlock = block as MessageContent & { - id?: string, - tool_use_id?: string, - content?: Array<{ type: string, text: string }>, - output?: string | Array<{ type: string, text: string }> - } + // Process complex content types + const content = anthropicMessage.content.map((block) => { + const messageBlock = block as MessageContent & { + id?: string + tool_use_id?: string + content?: Array<{ type: string; text: string }> + output?: string | Array<{ type: string; text: string }> + } - if (messageBlock.type === "text") { - return { - text: messageBlock.text || '' - } as ContentBlock - } - - if (messageBlock.type === "image" && messageBlock.source) { - // Convert base64 string to byte array if needed - let byteArray: Uint8Array - if (typeof messageBlock.source.data === 'string') { - const binaryString = atob(messageBlock.source.data) - byteArray = new Uint8Array(binaryString.length) - for (let i = 0; i < binaryString.length; i++) { - byteArray[i] = binaryString.charCodeAt(i) - } - } else { - byteArray = messageBlock.source.data - } + if (messageBlock.type === "text") { + return { + text: messageBlock.text || "", + } as ContentBlock + } - // Extract format from media_type (e.g., "image/jpeg" -> "jpeg") - const format = messageBlock.source.media_type.split('/')[1] - if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) { - throw new Error(`Unsupported image format: ${format}`) - } + if (messageBlock.type === "image" && messageBlock.source) { + // Convert base64 string to byte array if needed + let byteArray: Uint8Array + if (typeof messageBlock.source.data === "string") { + const binaryString = atob(messageBlock.source.data) + byteArray = new Uint8Array(binaryString.length) + for (let i = 0; i < binaryString.length; i++) { + byteArray[i] = binaryString.charCodeAt(i) + } + } else { + byteArray = messageBlock.source.data + } - return { - image: { - format: format as "png" | "jpeg" | "gif" | "webp", - source: { - bytes: byteArray - } - } - } as ContentBlock - } + // Extract format from media_type (e.g., "image/jpeg" -> "jpeg") + const format = messageBlock.source.media_type.split("/")[1] + if (!["png", "jpeg", "gif", "webp"].includes(format)) { + throw new Error(`Unsupported image format: ${format}`) + } - if (messageBlock.type === "tool_use") { - // Convert tool use to XML format - const toolParams = Object.entries(messageBlock.input || {}) - .map(([key, value]) => `<${key}>\n${value}\n`) - .join('\n') + return { + image: { + format: format as "png" | "jpeg" | "gif" | "webp", + source: { + bytes: byteArray, + }, + }, + } as ContentBlock + } - return { - toolUse: { - toolUseId: messageBlock.id || '', - name: messageBlock.name || '', - input: `<${messageBlock.name}>\n${toolParams}\n` - } - } as ContentBlock - } + if (messageBlock.type === "tool_use") { + // Convert tool use to XML format + const toolParams = Object.entries(messageBlock.input || {}) + .map(([key, value]) => `<${key}>\n${value}\n`) + .join("\n") - if (messageBlock.type === "tool_result") { - // First try to use content if available - if (messageBlock.content && Array.isArray(messageBlock.content)) { - return { - toolResult: { - toolUseId: messageBlock.tool_use_id || '', - content: messageBlock.content.map(item => ({ - text: item.text - })), - status: "success" - } - } as ContentBlock - } + return { + toolUse: { + toolUseId: messageBlock.id || "", + name: messageBlock.name || "", + input: `<${messageBlock.name}>\n${toolParams}\n`, + }, + } as ContentBlock + } - // Fall back to output handling if content is not available - if (messageBlock.output && typeof messageBlock.output === "string") { - return { - toolResult: { - toolUseId: messageBlock.tool_use_id || '', - content: [{ - text: messageBlock.output - }], - status: "success" - } - } as ContentBlock - } - // Handle array of content blocks if output is an array - if (Array.isArray(messageBlock.output)) { - return { - toolResult: { - toolUseId: messageBlock.tool_use_id || '', - content: messageBlock.output.map(part => { - if (typeof part === "object" && "text" in part) { - return { text: part.text } - } - // Skip images in tool results as they're handled separately - if (typeof part === "object" && "type" in part && part.type === "image") { - return { text: "(see following message for image)" } - } - return { text: String(part) } - }), - status: "success" - } - } as ContentBlock - } + if (messageBlock.type === "tool_result") { + // First try to use content if available + if (messageBlock.content && Array.isArray(messageBlock.content)) { + return { + toolResult: { + toolUseId: messageBlock.tool_use_id || "", + content: messageBlock.content.map((item) => ({ + text: item.text, + })), + status: "success", + }, + } as ContentBlock + } - // Default case - return { - toolResult: { - toolUseId: messageBlock.tool_use_id || '', - content: [{ - text: String(messageBlock.output || '') - }], - status: "success" - } - } as ContentBlock - } + // Fall back to output handling if content is not available + if (messageBlock.output && typeof messageBlock.output === "string") { + return { + toolResult: { + toolUseId: messageBlock.tool_use_id || "", + content: [ + { + text: messageBlock.output, + }, + ], + status: "success", + }, + } as ContentBlock + } + // Handle array of content blocks if output is an array + if (Array.isArray(messageBlock.output)) { + return { + toolResult: { + toolUseId: messageBlock.tool_use_id || "", + content: messageBlock.output.map((part) => { + if (typeof part === "object" && "text" in part) { + return { text: part.text } + } + // Skip images in tool results as they're handled separately + if (typeof part === "object" && "type" in part && part.type === "image") { + return { text: "(see following message for image)" } + } + return { text: String(part) } + }), + status: "success", + }, + } as ContentBlock + } - if (messageBlock.type === "video") { - const videoContent = messageBlock.s3Location ? { - s3Location: { - uri: messageBlock.s3Location.uri, - bucketOwner: messageBlock.s3Location.bucketOwner - } - } : messageBlock.source + // Default case + return { + toolResult: { + toolUseId: messageBlock.tool_use_id || "", + content: [ + { + text: String(messageBlock.output || ""), + }, + ], + status: "success", + }, + } as ContentBlock + } - return { - video: { - format: "mp4", // Default to mp4, adjust based on actual format if needed - source: videoContent - } - } as ContentBlock - } + if (messageBlock.type === "video") { + const videoContent = messageBlock.s3Location + ? { + s3Location: { + uri: messageBlock.s3Location.uri, + bucketOwner: messageBlock.s3Location.bucketOwner, + }, + } + : messageBlock.source - // Default case for unknown block types - return { - text: '[Unknown Block Type]' - } as ContentBlock - }) + return { + video: { + format: "mp4", // Default to mp4, adjust based on actual format if needed + source: videoContent, + }, + } as ContentBlock + } - return { - role, - content - } - }) + // Default case for unknown block types + return { + text: "[Unknown Block Type]", + } as ContentBlock + }) + + return { + role, + content, + } + }) } /** * Convert Bedrock Converse stream events to Anthropic message format */ export function convertToAnthropicMessage( - streamEvent: StreamEvent, - modelId: string + streamEvent: StreamEvent, + modelId: string, ): Partial { - // Handle metadata events - if (streamEvent.metadata?.usage) { - return { - id: '', // Bedrock doesn't provide message IDs - type: "message", - role: "assistant", - model: modelId, - usage: { - input_tokens: streamEvent.metadata.usage.inputTokens || 0, - output_tokens: streamEvent.metadata.usage.outputTokens || 0 - } - } - } + // Handle metadata events + if (streamEvent.metadata?.usage) { + return { + id: "", // Bedrock doesn't provide message IDs + type: "message", + role: "assistant", + model: modelId, + usage: { + input_tokens: streamEvent.metadata.usage.inputTokens || 0, + output_tokens: streamEvent.metadata.usage.outputTokens || 0, + }, + } + } - // Handle content blocks - const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text - if (text !== undefined) { - return { - type: "message", - role: "assistant", - content: [{ type: "text", text: text }], - model: modelId - } - } + // Handle content blocks + const text = streamEvent.contentBlockStart?.start?.text || streamEvent.contentBlockDelta?.delta?.text + if (text !== undefined) { + return { + type: "message", + role: "assistant", + content: [{ type: "text", text: text }], + model: modelId, + } + } - // Handle message stop - if (streamEvent.messageStop) { - return { - type: "message", - role: "assistant", - stop_reason: streamEvent.messageStop.stopReason || null, - stop_sequence: null, - model: modelId - } - } + // Handle message stop + if (streamEvent.messageStop) { + return { + type: "message", + role: "assistant", + stop_reason: streamEvent.messageStop.stopReason || null, + stop_sequence: null, + model: modelId, + } + } - return {} + return {} } diff --git a/src/api/transform/vscode-lm-format.ts b/src/api/transform/vscode-lm-format.ts index 5ccc6e6..acec365 100644 --- a/src/api/transform/vscode-lm-format.ts +++ b/src/api/transform/vscode-lm-format.ts @@ -1,5 +1,5 @@ -import { Anthropic } from "@anthropic-ai/sdk"; -import * as vscode from 'vscode'; +import { Anthropic } from "@anthropic-ai/sdk" +import * as vscode from "vscode" /** * Safely converts a value into a plain object. @@ -7,30 +7,31 @@ import * as vscode from 'vscode'; function asObjectSafe(value: any): object { // Handle null/undefined if (!value) { - return {}; + return {} } try { // Handle strings that might be JSON - if (typeof value === 'string') { - return JSON.parse(value); + if (typeof value === "string") { + return JSON.parse(value) } // Handle pre-existing objects - if (typeof value === 'object') { - return Object.assign({}, value); + if (typeof value === "object") { + return Object.assign({}, value) } - return {}; - } - catch (error) { - console.warn('Cline : Failed to parse object:', error); - return {}; + return {} + } catch (error) { + console.warn("Cline : Failed to parse object:", error) + return {} } } -export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): vscode.LanguageModelChatMessage[] { - const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []; +export function convertToVsCodeLmMessages( + anthropicMessages: Anthropic.Messages.MessageParam[], +): vscode.LanguageModelChatMessage[] { + const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [] for (const anthropicMessage of anthropicMessages) { // Handle simple string messages @@ -38,135 +39,129 @@ export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages. vsCodeLmMessages.push( anthropicMessage.role === "assistant" ? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content) - : vscode.LanguageModelChatMessage.User(anthropicMessage.content) - ); - continue; + : vscode.LanguageModelChatMessage.User(anthropicMessage.content), + ) + continue } // Handle complex message structures switch (anthropicMessage.role) { case "user": { const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ - nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]; - toolMessages: Anthropic.ToolResultBlockParam[]; + nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] + toolMessages: Anthropic.ToolResultBlockParam[] }>( (acc, part) => { if (part.type === "tool_result") { - acc.toolMessages.push(part); + acc.toolMessages.push(part) + } else if (part.type === "text" || part.type === "image") { + acc.nonToolMessages.push(part) } - else if (part.type === "text" || part.type === "image") { - acc.nonToolMessages.push(part); - } - return acc; + return acc }, { nonToolMessages: [], toolMessages: [] }, - ); + ) // Process tool messages first then non-tool messages const contentParts = [ // Convert tool messages to ToolResultParts ...toolMessages.map((toolMessage) => { // Process tool result content into TextParts - const toolContentParts: vscode.LanguageModelTextPart[] = ( + const toolContentParts: vscode.LanguageModelTextPart[] = typeof toolMessage.content === "string" ? [new vscode.LanguageModelTextPart(toolMessage.content)] - : ( - toolMessage.content?.map((part) => { + : (toolMessage.content?.map((part) => { if (part.type === "image") { return new vscode.LanguageModelTextPart( - `[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]` - ); + `[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`, + ) } - return new vscode.LanguageModelTextPart(part.text); - }) - ?? [new vscode.LanguageModelTextPart("")] - ) - ); + return new vscode.LanguageModelTextPart(part.text) + }) ?? [new vscode.LanguageModelTextPart("")]) - return new vscode.LanguageModelToolResultPart( - toolMessage.tool_use_id, - toolContentParts - ); + return new vscode.LanguageModelToolResultPart(toolMessage.tool_use_id, toolContentParts) }), // Convert non-tool messages to TextParts after tool messages ...nonToolMessages.map((part) => { if (part.type === "image") { return new vscode.LanguageModelTextPart( - `[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]` - ); + `[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`, + ) } - return new vscode.LanguageModelTextPart(part.text); - }) - ]; + return new vscode.LanguageModelTextPart(part.text) + }), + ] // Add single user message with all content parts - vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts)); - break; + vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts)) + break } case "assistant": { const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ - nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]; - toolMessages: Anthropic.ToolUseBlockParam[]; + nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] + toolMessages: Anthropic.ToolUseBlockParam[] }>( (acc, part) => { if (part.type === "tool_use") { - acc.toolMessages.push(part); + acc.toolMessages.push(part) + } else if (part.type === "text" || part.type === "image") { + acc.nonToolMessages.push(part) } - else if (part.type === "text" || part.type === "image") { - acc.nonToolMessages.push(part); - } - return acc; + return acc }, { nonToolMessages: [], toolMessages: [] }, - ); + ) - // Process tool messages first then non-tool messages + // Process tool messages first then non-tool messages const contentParts = [ // Convert tool messages to ToolCallParts first - ...toolMessages.map((toolMessage) => - new vscode.LanguageModelToolCallPart( - toolMessage.id, - toolMessage.name, - asObjectSafe(toolMessage.input) - ) + ...toolMessages.map( + (toolMessage) => + new vscode.LanguageModelToolCallPart( + toolMessage.id, + toolMessage.name, + asObjectSafe(toolMessage.input), + ), ), // Convert non-tool messages to TextParts after tool messages ...nonToolMessages.map((part) => { if (part.type === "image") { - return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]"); + return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]") } - return new vscode.LanguageModelTextPart(part.text); - }) - ]; + return new vscode.LanguageModelTextPart(part.text) + }), + ] // Add the assistant message to the list of messages - vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts)); - break; + vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts)) + break } } } - return vsCodeLmMessages; + return vsCodeLmMessages } export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null { switch (vsCodeLmMessageRole) { case vscode.LanguageModelChatMessageRole.Assistant: - return "assistant"; + return "assistant" case vscode.LanguageModelChatMessageRole.User: - return "user"; + return "user" default: - return null; + return null } } -export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.LanguageModelChatMessage): Promise { - const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role); +export async function convertToAnthropicMessage( + vsCodeLmMessage: vscode.LanguageModelChatMessage, +): Promise { + const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role) if (anthropicRole !== "assistant") { - throw new Error("Cline : Only assistant messages are supported."); + throw new Error("Cline : Only assistant messages are supported.") } return { @@ -174,36 +169,32 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language type: "message", model: "vscode-lm", role: anthropicRole, - content: ( - vsCodeLmMessage.content - .map((part): Anthropic.ContentBlock | null => { - if (part instanceof vscode.LanguageModelTextPart) { - return { - type: "text", - text: part.value - }; + content: vsCodeLmMessage.content + .map((part): Anthropic.ContentBlock | null => { + if (part instanceof vscode.LanguageModelTextPart) { + return { + type: "text", + text: part.value, } + } - if (part instanceof vscode.LanguageModelToolCallPart) { - return { - type: "tool_use", - id: part.callId || crypto.randomUUID(), - name: part.name, - input: asObjectSafe(part.input) - }; + if (part instanceof vscode.LanguageModelToolCallPart) { + return { + type: "tool_use", + id: part.callId || crypto.randomUUID(), + name: part.name, + input: asObjectSafe(part.input), } + } - return null; - }) - .filter( - (part): part is Anthropic.ContentBlock => part !== null - ) - ), + return null + }) + .filter((part): part is Anthropic.ContentBlock => part !== null), stop_reason: null, stop_sequence: null, usage: { input_tokens: 0, output_tokens: 0, - } - }; + }, + } } diff --git a/src/core/Cline.ts b/src/core/Cline.ts index 87d020c..eb78cc4 100644 --- a/src/core/Cline.ts +++ b/src/core/Cline.ts @@ -13,7 +13,13 @@ import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api" import { ApiStream } from "../api/transform/stream" import { DiffViewProvider } from "../integrations/editor/DiffViewProvider" import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown" -import { extractTextFromFile, addLineNumbers, stripLineNumbers, everyLineHasLineNumbers, truncateOutput } from "../integrations/misc/extract-text" +import { + extractTextFromFile, + addLineNumbers, + stripLineNumbers, + everyLineHasLineNumbers, + truncateOutput, +} from "../integrations/misc/extract-text" import { TerminalManager } from "../integrations/terminal/TerminalManager" import { UrlContentFetcher } from "../services/browser/UrlContentFetcher" import { listFiles } from "../services/glob/list-files" @@ -112,7 +118,7 @@ export class Cline { experimentalDiffStrategy: boolean = false, ) { if (!task && !images && !historyItem) { - throw new Error('Either historyItem or task/images must be provided'); + throw new Error("Either historyItem or task/images must be provided") } this.taskId = crypto.randomUUID() @@ -144,7 +150,8 @@ export class Cline { async updateDiffStrategy(experimentalDiffStrategy?: boolean) { // If not provided, get from current state if (experimentalDiffStrategy === undefined) { - const { experimentalDiffStrategy: stateExperimentalDiffStrategy } = await this.providerRef.deref()?.getState() ?? {} + const { experimentalDiffStrategy: stateExperimentalDiffStrategy } = + (await this.providerRef.deref()?.getState()) ?? {} experimentalDiffStrategy = stateExperimentalDiffStrategy ?? false } this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy) @@ -471,7 +478,7 @@ export class Cline { // need to make sure that the api conversation history can be resumed by the api, even if it goes out of sync with cline messages let existingApiConversationHistory: Anthropic.Messages.MessageParam[] = - await this.getSavedApiConversationHistory() + await this.getSavedApiConversationHistory() // Now present the cline messages to the user and ask if they want to resume @@ -582,8 +589,8 @@ export class Cline { : [{ type: "text", text: lastMessage.content }] if (previousAssistantMessage && previousAssistantMessage.role === "assistant") { const assistantContent = Array.isArray(previousAssistantMessage.content) - ? previousAssistantMessage.content - : [{ type: "text", text: previousAssistantMessage.content }] + ? previousAssistantMessage.content + : [{ type: "text", text: previousAssistantMessage.content }] const toolUseBlocks = assistantContent.filter( (block) => block.type === "tool_use", @@ -756,8 +763,8 @@ export class Cline { // grouping command_output messages despite any gaps anyways) await delay(50) - const { terminalOutputLineLimit } = await this.providerRef.deref()?.getState() ?? {} - const output = truncateOutput(lines.join('\n'), terminalOutputLineLimit) + const { terminalOutputLineLimit } = (await this.providerRef.deref()?.getState()) ?? {} + const output = truncateOutput(lines.join("\n"), terminalOutputLineLimit) const result = output.trim() if (userFeedback) { @@ -788,7 +795,8 @@ export class Cline { async *attemptApiRequest(previousApiReqIndex: number): ApiStream { let mcpHub: McpHub | undefined - const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } = await this.providerRef.deref()?.getState() ?? {} + const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } = + (await this.providerRef.deref()?.getState()) ?? {} if (mcpEnabled ?? true) { mcpHub = this.providerRef.deref()?.mcpHub @@ -801,24 +809,27 @@ export class Cline { }) } - const { browserViewportSize, preferredLanguage, mode, customPrompts } = await this.providerRef.deref()?.getState() ?? {} - const systemPrompt = await SYSTEM_PROMPT( - cwd, - this.api.getModel().info.supportsComputerUse ?? false, - mcpHub, - this.diffStrategy, - browserViewportSize, - mode, - customPrompts - ) + await addCustomInstructions( - { - customInstructions: this.customInstructions, + const { browserViewportSize, preferredLanguage, mode, customPrompts } = + (await this.providerRef.deref()?.getState()) ?? {} + const systemPrompt = + (await SYSTEM_PROMPT( + cwd, + this.api.getModel().info.supportsComputerUse ?? false, + mcpHub, + this.diffStrategy, + browserViewportSize, + mode, customPrompts, - preferredLanguage - }, - cwd, - mode - ) + )) + + (await addCustomInstructions( + { + customInstructions: this.customInstructions, + customPrompts, + preferredLanguage, + }, + cwd, + mode, + )) // If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request if (previousApiReqIndex >= 0) { @@ -845,18 +856,18 @@ export class Cline { if (Array.isArray(content)) { if (!this.api.getModel().info.supportsImages) { // Convert image blocks to text descriptions - content = content.map(block => { - if (block.type === 'image') { + content = content.map((block) => { + if (block.type === "image") { // Convert image blocks to text descriptions // Note: We can't access the actual image content/url due to API limitations, // but we can indicate that an image was present in the conversation return { - type: 'text', - text: '[Referenced image in conversation]' - }; + type: "text", + text: "[Referenced image in conversation]", + } } - return block; - }); + return block + }) } } return { role, content } @@ -876,7 +887,12 @@ export class Cline { // Automatically retry with delay // Show countdown timer in error color for (let i = requestDelay; i > 0; i--) { - await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying in ${i} seconds...`, undefined, true) + await this.say( + "api_req_retry_delayed", + `${errorMsg}\n\nRetrying in ${i} seconds...`, + undefined, + true, + ) await delay(1000) } await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying now...`, undefined, false) @@ -1125,7 +1141,7 @@ export class Cline { } // Validate tool use based on current mode - const { mode } = await this.providerRef.deref()?.getState() ?? {} + const { mode } = (await this.providerRef.deref()?.getState()) ?? {} try { validateToolUse(block.name, mode ?? defaultModeSlug) } catch (error) { @@ -1192,7 +1208,10 @@ export class Cline { await this.diffViewProvider.open(relPath) } // editor is open, stream content in - await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, false) + await this.diffViewProvider.update( + everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, + false, + ) break } else { if (!relPath) { @@ -1209,7 +1228,9 @@ export class Cline { } if (!predictedLineCount) { this.consecutiveMistakeCount++ - pushToolResult(await this.sayAndCreateMissingParamError("write_to_file", "line_count")) + pushToolResult( + await this.sayAndCreateMissingParamError("write_to_file", "line_count"), + ) await this.diffViewProvider.reset() break } @@ -1224,17 +1245,28 @@ export class Cline { await this.ask("tool", partialMessage, true).catch(() => {}) // sending true for partial even though it's not a partial, this shows the edit row before the content is streamed into the editor await this.diffViewProvider.open(relPath) } - await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, true) + await this.diffViewProvider.update( + everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, + true, + ) await delay(300) // wait for diff view to update this.diffViewProvider.scrollToFirstDiff() // Check for code omissions before proceeding - if (detectCodeOmission(this.diffViewProvider.originalContent || "", newContent, predictedLineCount)) { + if ( + detectCodeOmission( + this.diffViewProvider.originalContent || "", + newContent, + predictedLineCount, + ) + ) { if (this.diffStrategy) { await this.diffViewProvider.revertChanges() - pushToolResult(formatResponse.toolError( - `Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.` - )) + pushToolResult( + formatResponse.toolError( + `Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.`, + ), + ) break } else { vscode.window @@ -1285,7 +1317,7 @@ export class Cline { pushToolResult( `The user made the following updates to your content:\n\n${userEdits}\n\n` + `The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` + - `\n${addLineNumbers(finalContent || '')}\n\n\n` + + `\n${addLineNumbers(finalContent || "")}\n\n\n` + `Please note:\n` + `1. You do not need to re-write the file with these changes, as they have already been applied.\n` + `2. Proceed with the task using this updated file content as the new baseline.\n` + @@ -1347,21 +1379,24 @@ export class Cline { const originalContent = await fs.readFile(absolutePath, "utf-8") // Apply the diff to the original content - const diffResult = await this.diffStrategy?.applyDiff( - originalContent, - diffContent, - parseInt(block.params.start_line ?? ''), - parseInt(block.params.end_line ?? '') - ) ?? { + const diffResult = (await this.diffStrategy?.applyDiff( + originalContent, + diffContent, + parseInt(block.params.start_line ?? ""), + parseInt(block.params.end_line ?? ""), + )) ?? { success: false, - error: "No diff strategy available" + error: "No diff strategy available", } if (!diffResult.success) { this.consecutiveMistakeCount++ - const currentCount = (this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1 + const currentCount = + (this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1 this.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount) - const errorDetails = diffResult.details ? JSON.stringify(diffResult.details, null, 2) : '' - const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ''}\n` + const errorDetails = diffResult.details + ? JSON.stringify(diffResult.details, null, 2) + : "" + const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ""}\n` if (currentCount >= 2) { await this.say("error", formattedError) } @@ -1373,9 +1408,9 @@ export class Cline { this.consecutiveMistakeCountForApplyDiff.delete(relPath) // Show diff view before asking for approval this.diffViewProvider.editType = "modify" - await this.diffViewProvider.open(relPath); - await this.diffViewProvider.update(diffResult.content, true); - await this.diffViewProvider.scrollToFirstDiff(); + await this.diffViewProvider.open(relPath) + await this.diffViewProvider.update(diffResult.content, true) + await this.diffViewProvider.scrollToFirstDiff() const completeMessage = JSON.stringify({ ...sharedMessageProps, @@ -1403,7 +1438,7 @@ export class Cline { pushToolResult( `The user made the following updates to your content:\n\n${userEdits}\n\n` + `The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` + - `\n${addLineNumbers(finalContent || '')}\n\n\n` + + `\n${addLineNumbers(finalContent || "")}\n\n\n` + `Please note:\n` + `1. You do not need to re-write the file with these changes, as they have already been applied.\n` + `2. Proceed with the task using this updated file content as the new baseline.\n` + @@ -1411,7 +1446,9 @@ export class Cline { `${newProblemsMessage}`, ) } else { - pushToolResult(`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`) + pushToolResult( + `Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`, + ) } await this.diffViewProvider.reset() break @@ -1615,7 +1652,7 @@ export class Cline { await this.ask( "browser_action_launch", removeClosingTag("url", url), - block.partial + block.partial, ).catch(() => {}) } else { await this.say( @@ -1744,7 +1781,7 @@ export class Cline { try { if (block.partial) { await this.ask("command", removeClosingTag("command", command), block.partial).catch( - () => {} + () => {}, ) break } else { @@ -2409,7 +2446,7 @@ export class Cline { Promise.all( userContent.map(async (block) => { const shouldProcessMentions = (text: string) => - text.includes("") || text.includes(""); + text.includes("") || text.includes("") if (block.type === "text") { if (shouldProcessMentions(block.text)) { @@ -2418,7 +2455,7 @@ export class Cline { text: await parseMentions(block.text, cwd, this.urlContentFetcher), } } - return block; + return block } else if (block.type === "tool_result") { if (typeof block.content === "string") { if (shouldProcessMentions(block.content)) { @@ -2427,7 +2464,7 @@ export class Cline { content: await parseMentions(block.content, cwd, this.urlContentFetcher), } } - return block; + return block } else if (Array.isArray(block.content)) { const parsedContent = await Promise.all( block.content.map(async (contentBlock) => { @@ -2445,7 +2482,7 @@ export class Cline { content: parsedContent, } } - return block; + return block } return block }), @@ -2571,26 +2608,29 @@ export class Cline { // Add current time information with timezone const now = new Date() const formatter = new Intl.DateTimeFormat(undefined, { - year: 'numeric', - month: 'numeric', - day: 'numeric', - hour: 'numeric', - minute: 'numeric', - second: 'numeric', - hour12: true + year: "numeric", + month: "numeric", + day: "numeric", + hour: "numeric", + minute: "numeric", + second: "numeric", + hour12: true, }) const timeZone = formatter.resolvedOptions().timeZone const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation - const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? '+' : ''}${timeZoneOffset}:00` + const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? "+" : ""}${timeZoneOffset}:00` details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})` // Add current mode and any mode-specific warnings - const { mode } = await this.providerRef.deref()?.getState() ?? {} + const { mode } = (await this.providerRef.deref()?.getState()) ?? {} const currentMode = mode ?? defaultModeSlug details += `\n\n# Current Mode\n${currentMode}` // Add warning if not in code mode - if (!isToolAllowedForMode('write_to_file', currentMode) || !isToolAllowedForMode('execute_command', currentMode)) { + if ( + !isToolAllowedForMode("write_to_file", currentMode) || + !isToolAllowedForMode("execute_command", currentMode) + ) { details += `\n\nNOTE: You are currently in '${currentMode}' mode which only allows read-only operations. To write files or execute commands, the user will need to switch to '${defaultModeSlug}' mode. Note that only the user can switch modes.` } @@ -2609,4 +2649,4 @@ export class Cline { return `\n${details.trim()}\n` } -} \ No newline at end of file +} diff --git a/src/core/__tests__/Cline.test.ts b/src/core/__tests__/Cline.test.ts index 11e9f9d..4cf0def 100644 --- a/src/core/__tests__/Cline.test.ts +++ b/src/core/__tests__/Cline.test.ts @@ -1,835 +1,807 @@ -import { Cline } from '../Cline'; -import { ClineProvider } from '../webview/ClineProvider'; -import { ApiConfiguration, ModelInfo } from '../../shared/api'; -import { ApiStreamChunk } from '../../api/transform/stream'; -import { Anthropic } from '@anthropic-ai/sdk'; -import * as vscode from 'vscode'; +import { Cline } from "../Cline" +import { ClineProvider } from "../webview/ClineProvider" +import { ApiConfiguration, ModelInfo } from "../../shared/api" +import { ApiStreamChunk } from "../../api/transform/stream" +import { Anthropic } from "@anthropic-ai/sdk" +import * as vscode from "vscode" // Mock all MCP-related modules -jest.mock('@modelcontextprotocol/sdk/types.js', () => ({ - CallToolResultSchema: {}, - ListResourcesResultSchema: {}, - ListResourceTemplatesResultSchema: {}, - ListToolsResultSchema: {}, - ReadResourceResultSchema: {}, - ErrorCode: { - InvalidRequest: 'InvalidRequest', - MethodNotFound: 'MethodNotFound', - InternalError: 'InternalError' - }, - McpError: class McpError extends Error { - code: string; - constructor(code: string, message: string) { - super(message); - this.code = code; - this.name = 'McpError'; - } - } -}), { virtual: true }); +jest.mock( + "@modelcontextprotocol/sdk/types.js", + () => ({ + CallToolResultSchema: {}, + ListResourcesResultSchema: {}, + ListResourceTemplatesResultSchema: {}, + ListToolsResultSchema: {}, + ReadResourceResultSchema: {}, + ErrorCode: { + InvalidRequest: "InvalidRequest", + MethodNotFound: "MethodNotFound", + InternalError: "InternalError", + }, + McpError: class McpError extends Error { + code: string + constructor(code: string, message: string) { + super(message) + this.code = code + this.name = "McpError" + } + }, + }), + { virtual: true }, +) -jest.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ - Client: jest.fn().mockImplementation(() => ({ - connect: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - listTools: jest.fn().mockResolvedValue({ tools: [] }), - callTool: jest.fn().mockResolvedValue({ content: [] }) - })) -}), { virtual: true }); +jest.mock( + "@modelcontextprotocol/sdk/client/index.js", + () => ({ + Client: jest.fn().mockImplementation(() => ({ + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + listTools: jest.fn().mockResolvedValue({ tools: [] }), + callTool: jest.fn().mockResolvedValue({ content: [] }), + })), + }), + { virtual: true }, +) -jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({ - StdioClientTransport: jest.fn().mockImplementation(() => ({ - connect: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined) - })) -}), { virtual: true }); +jest.mock( + "@modelcontextprotocol/sdk/client/stdio.js", + () => ({ + StdioClientTransport: jest.fn().mockImplementation(() => ({ + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + })), + }), + { virtual: true }, +) // Mock fileExistsAtPath -jest.mock('../../utils/fs', () => ({ - fileExistsAtPath: jest.fn().mockImplementation((filePath) => { - return filePath.includes('ui_messages.json') || - filePath.includes('api_conversation_history.json'); - }) -})); +jest.mock("../../utils/fs", () => ({ + fileExistsAtPath: jest.fn().mockImplementation((filePath) => { + return filePath.includes("ui_messages.json") || filePath.includes("api_conversation_history.json") + }), +})) // Mock fs/promises -const mockMessages = [{ - ts: Date.now(), - type: 'say', - say: 'text', - text: 'historical task' -}]; +const mockMessages = [ + { + ts: Date.now(), + type: "say", + say: "text", + text: "historical task", + }, +] -jest.mock('fs/promises', () => ({ - mkdir: jest.fn().mockResolvedValue(undefined), - writeFile: jest.fn().mockResolvedValue(undefined), - readFile: jest.fn().mockImplementation((filePath) => { - if (filePath.includes('ui_messages.json')) { - return Promise.resolve(JSON.stringify(mockMessages)); - } - if (filePath.includes('api_conversation_history.json')) { - return Promise.resolve('[]'); - } - return Promise.resolve('[]'); - }), - unlink: jest.fn().mockResolvedValue(undefined), - rmdir: jest.fn().mockResolvedValue(undefined) -})); +jest.mock("fs/promises", () => ({ + mkdir: jest.fn().mockResolvedValue(undefined), + writeFile: jest.fn().mockResolvedValue(undefined), + readFile: jest.fn().mockImplementation((filePath) => { + if (filePath.includes("ui_messages.json")) { + return Promise.resolve(JSON.stringify(mockMessages)) + } + if (filePath.includes("api_conversation_history.json")) { + return Promise.resolve("[]") + } + return Promise.resolve("[]") + }), + unlink: jest.fn().mockResolvedValue(undefined), + rmdir: jest.fn().mockResolvedValue(undefined), +})) // Mock dependencies -jest.mock('vscode', () => { - const mockDisposable = { dispose: jest.fn() }; - const mockEventEmitter = { - event: jest.fn(), - fire: jest.fn() - }; +jest.mock("vscode", () => { + const mockDisposable = { dispose: jest.fn() } + const mockEventEmitter = { + event: jest.fn(), + fire: jest.fn(), + } - const mockTextDocument = { - uri: { - fsPath: '/mock/workspace/path/file.ts' - } - }; + const mockTextDocument = { + uri: { + fsPath: "/mock/workspace/path/file.ts", + }, + } - const mockTextEditor = { - document: mockTextDocument - }; + const mockTextEditor = { + document: mockTextDocument, + } - const mockTab = { - input: { - uri: { - fsPath: '/mock/workspace/path/file.ts' - } - } - }; + const mockTab = { + input: { + uri: { + fsPath: "/mock/workspace/path/file.ts", + }, + }, + } - const mockTabGroup = { - tabs: [mockTab] - }; + const mockTabGroup = { + tabs: [mockTab], + } - return { - window: { - createTextEditorDecorationType: jest.fn().mockReturnValue({ - dispose: jest.fn() - }), - visibleTextEditors: [mockTextEditor], - tabGroups: { - all: [mockTabGroup] - } - }, - workspace: { - workspaceFolders: [{ - uri: { - fsPath: '/mock/workspace/path' - }, - name: 'mock-workspace', - index: 0 - }], - createFileSystemWatcher: jest.fn(() => ({ - onDidCreate: jest.fn(() => mockDisposable), - onDidDelete: jest.fn(() => mockDisposable), - onDidChange: jest.fn(() => mockDisposable), - dispose: jest.fn() - })), - fs: { - stat: jest.fn().mockResolvedValue({ type: 1 }) // FileType.File = 1 - }, - onDidSaveTextDocument: jest.fn(() => mockDisposable) - }, - env: { - uriScheme: 'vscode', - language: 'en' - }, - EventEmitter: jest.fn().mockImplementation(() => mockEventEmitter), - Disposable: { - from: jest.fn() - }, - TabInputText: jest.fn() - }; -}); + return { + window: { + createTextEditorDecorationType: jest.fn().mockReturnValue({ + dispose: jest.fn(), + }), + visibleTextEditors: [mockTextEditor], + tabGroups: { + all: [mockTabGroup], + }, + }, + workspace: { + workspaceFolders: [ + { + uri: { + fsPath: "/mock/workspace/path", + }, + name: "mock-workspace", + index: 0, + }, + ], + createFileSystemWatcher: jest.fn(() => ({ + onDidCreate: jest.fn(() => mockDisposable), + onDidDelete: jest.fn(() => mockDisposable), + onDidChange: jest.fn(() => mockDisposable), + dispose: jest.fn(), + })), + fs: { + stat: jest.fn().mockResolvedValue({ type: 1 }), // FileType.File = 1 + }, + onDidSaveTextDocument: jest.fn(() => mockDisposable), + }, + env: { + uriScheme: "vscode", + language: "en", + }, + EventEmitter: jest.fn().mockImplementation(() => mockEventEmitter), + Disposable: { + from: jest.fn(), + }, + TabInputText: jest.fn(), + } +}) // Mock p-wait-for to resolve immediately -jest.mock('p-wait-for', () => ({ - __esModule: true, - default: jest.fn().mockImplementation(async () => Promise.resolve()) -})); +jest.mock("p-wait-for", () => ({ + __esModule: true, + default: jest.fn().mockImplementation(async () => Promise.resolve()), +})) -jest.mock('delay', () => ({ - __esModule: true, - default: jest.fn().mockImplementation(async () => Promise.resolve()) -})); +jest.mock("delay", () => ({ + __esModule: true, + default: jest.fn().mockImplementation(async () => Promise.resolve()), +})) -jest.mock('serialize-error', () => ({ - __esModule: true, - serializeError: jest.fn().mockImplementation((error) => ({ - name: error.name, - message: error.message, - stack: error.stack - })) -})); +jest.mock("serialize-error", () => ({ + __esModule: true, + serializeError: jest.fn().mockImplementation((error) => ({ + name: error.name, + message: error.message, + stack: error.stack, + })), +})) -jest.mock('strip-ansi', () => ({ - __esModule: true, - default: jest.fn().mockImplementation((str) => str.replace(/\u001B\[\d+m/g, '')) -})); +jest.mock("strip-ansi", () => ({ + __esModule: true, + default: jest.fn().mockImplementation((str) => str.replace(/\u001B\[\d+m/g, "")), +})) -jest.mock('globby', () => ({ - __esModule: true, - globby: jest.fn().mockImplementation(async () => []) -})); +jest.mock("globby", () => ({ + __esModule: true, + globby: jest.fn().mockImplementation(async () => []), +})) -jest.mock('os-name', () => ({ - __esModule: true, - default: jest.fn().mockReturnValue('Mock OS Name') -})); +jest.mock("os-name", () => ({ + __esModule: true, + default: jest.fn().mockReturnValue("Mock OS Name"), +})) -jest.mock('default-shell', () => ({ - __esModule: true, - default: '/bin/bash' // Mock default shell path -})); +jest.mock("default-shell", () => ({ + __esModule: true, + default: "/bin/bash", // Mock default shell path +})) -describe('Cline', () => { - let mockProvider: jest.Mocked; - let mockApiConfig: ApiConfiguration; - let mockOutputChannel: any; - let mockExtensionContext: vscode.ExtensionContext; - - beforeEach(() => { - // Setup mock extension context - mockExtensionContext = { - globalState: { - get: jest.fn().mockImplementation((key) => { - if (key === 'taskHistory') { - return [{ - id: '123', - ts: Date.now(), - task: 'historical task', - tokensIn: 100, - tokensOut: 200, - cacheWrites: 0, - cacheReads: 0, - totalCost: 0.001 - }]; - } - return undefined; - }), - update: jest.fn().mockImplementation((key, value) => Promise.resolve()), - keys: jest.fn().mockReturnValue([]) - }, - workspaceState: { - get: jest.fn().mockImplementation((key) => undefined), - update: jest.fn().mockImplementation((key, value) => Promise.resolve()), - keys: jest.fn().mockReturnValue([]) - }, - secrets: { - get: jest.fn().mockImplementation((key) => Promise.resolve(undefined)), - store: jest.fn().mockImplementation((key, value) => Promise.resolve()), - delete: jest.fn().mockImplementation((key) => Promise.resolve()) - }, - extensionUri: { - fsPath: '/mock/extension/path' - }, - globalStorageUri: { - fsPath: '/mock/storage/path' - }, - extension: { - packageJSON: { - version: '1.0.0' - } - } - } as unknown as vscode.ExtensionContext; +describe("Cline", () => { + let mockProvider: jest.Mocked + let mockApiConfig: ApiConfiguration + let mockOutputChannel: any + let mockExtensionContext: vscode.ExtensionContext - // Setup mock output channel - mockOutputChannel = { - appendLine: jest.fn(), - append: jest.fn(), - clear: jest.fn(), - show: jest.fn(), - hide: jest.fn(), - dispose: jest.fn() - }; + beforeEach(() => { + // Setup mock extension context + mockExtensionContext = { + globalState: { + get: jest.fn().mockImplementation((key) => { + if (key === "taskHistory") { + return [ + { + id: "123", + ts: Date.now(), + task: "historical task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + }, + ] + } + return undefined + }), + update: jest.fn().mockImplementation((key, value) => Promise.resolve()), + keys: jest.fn().mockReturnValue([]), + }, + workspaceState: { + get: jest.fn().mockImplementation((key) => undefined), + update: jest.fn().mockImplementation((key, value) => Promise.resolve()), + keys: jest.fn().mockReturnValue([]), + }, + secrets: { + get: jest.fn().mockImplementation((key) => Promise.resolve(undefined)), + store: jest.fn().mockImplementation((key, value) => Promise.resolve()), + delete: jest.fn().mockImplementation((key) => Promise.resolve()), + }, + extensionUri: { + fsPath: "/mock/extension/path", + }, + globalStorageUri: { + fsPath: "/mock/storage/path", + }, + extension: { + packageJSON: { + version: "1.0.0", + }, + }, + } as unknown as vscode.ExtensionContext - // Setup mock provider with output channel - mockProvider = new ClineProvider(mockExtensionContext, mockOutputChannel) as jest.Mocked; - - // Setup mock API configuration - mockApiConfig = { - apiProvider: 'anthropic', - apiModelId: 'claude-3-5-sonnet-20241022', - apiKey: 'test-api-key' // Add API key to mock config - }; + // Setup mock output channel + mockOutputChannel = { + appendLine: jest.fn(), + append: jest.fn(), + clear: jest.fn(), + show: jest.fn(), + hide: jest.fn(), + dispose: jest.fn(), + } - // Mock provider methods - mockProvider.postMessageToWebview = jest.fn().mockResolvedValue(undefined); - mockProvider.postStateToWebview = jest.fn().mockResolvedValue(undefined); - mockProvider.getTaskWithId = jest.fn().mockImplementation(async (id) => ({ - historyItem: { - id, - ts: Date.now(), - task: 'historical task', - tokensIn: 100, - tokensOut: 200, - cacheWrites: 0, - cacheReads: 0, - totalCost: 0.001 - }, - taskDirPath: '/mock/storage/path/tasks/123', - apiConversationHistoryFilePath: '/mock/storage/path/tasks/123/api_conversation_history.json', - uiMessagesFilePath: '/mock/storage/path/tasks/123/ui_messages.json', - apiConversationHistory: [] - })); - }); + // Setup mock provider with output channel + mockProvider = new ClineProvider(mockExtensionContext, mockOutputChannel) as jest.Mocked - describe('constructor', () => { - it('should respect provided settings', () => { - const cline = new Cline( - mockProvider, - mockApiConfig, - 'custom instructions', - false, - 0.95, // 95% threshold - 'test task' - ); + // Setup mock API configuration + mockApiConfig = { + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + apiKey: "test-api-key", // Add API key to mock config + } - expect(cline.customInstructions).toBe('custom instructions'); - expect(cline.diffEnabled).toBe(false); - }); + // Mock provider methods + mockProvider.postMessageToWebview = jest.fn().mockResolvedValue(undefined) + mockProvider.postStateToWebview = jest.fn().mockResolvedValue(undefined) + mockProvider.getTaskWithId = jest.fn().mockImplementation(async (id) => ({ + historyItem: { + id, + ts: Date.now(), + task: "historical task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + }, + taskDirPath: "/mock/storage/path/tasks/123", + apiConversationHistoryFilePath: "/mock/storage/path/tasks/123/api_conversation_history.json", + uiMessagesFilePath: "/mock/storage/path/tasks/123/ui_messages.json", + apiConversationHistory: [], + })) + }) - it('should use default fuzzy match threshold when not provided', () => { - const cline = new Cline( - mockProvider, - mockApiConfig, - 'custom instructions', - true, - undefined, - 'test task' - ); + describe("constructor", () => { + it("should respect provided settings", () => { + const cline = new Cline( + mockProvider, + mockApiConfig, + "custom instructions", + false, + 0.95, // 95% threshold + "test task", + ) - expect(cline.diffEnabled).toBe(true); - // The diff strategy should be created with default threshold (1.0) - expect(cline.diffStrategy).toBeDefined(); - }); + expect(cline.customInstructions).toBe("custom instructions") + expect(cline.diffEnabled).toBe(false) + }) - it('should use provided fuzzy match threshold', () => { - const getDiffStrategySpy = jest.spyOn(require('../diff/DiffStrategy'), 'getDiffStrategy'); - - const cline = new Cline( - mockProvider, - mockApiConfig, - 'custom instructions', - true, - 0.9, // 90% threshold - 'test task' - ); + it("should use default fuzzy match threshold when not provided", () => { + const cline = new Cline(mockProvider, mockApiConfig, "custom instructions", true, undefined, "test task") - expect(cline.diffEnabled).toBe(true); - expect(cline.diffStrategy).toBeDefined(); - expect(getDiffStrategySpy).toHaveBeenCalledWith('claude-3-5-sonnet-20241022', 0.9, false); - - getDiffStrategySpy.mockRestore(); - }); + expect(cline.diffEnabled).toBe(true) + // The diff strategy should be created with default threshold (1.0) + expect(cline.diffStrategy).toBeDefined() + }) - it('should pass default threshold to diff strategy when not provided', () => { - const getDiffStrategySpy = jest.spyOn(require('../diff/DiffStrategy'), 'getDiffStrategy'); - - const cline = new Cline( - mockProvider, - mockApiConfig, - 'custom instructions', - true, - undefined, - 'test task' - ); + it("should use provided fuzzy match threshold", () => { + const getDiffStrategySpy = jest.spyOn(require("../diff/DiffStrategy"), "getDiffStrategy") - expect(cline.diffEnabled).toBe(true); - expect(cline.diffStrategy).toBeDefined(); - expect(getDiffStrategySpy).toHaveBeenCalledWith('claude-3-5-sonnet-20241022', 1.0, false); - - getDiffStrategySpy.mockRestore(); - }); + const cline = new Cline( + mockProvider, + mockApiConfig, + "custom instructions", + true, + 0.9, // 90% threshold + "test task", + ) - it('should require either task or historyItem', () => { - expect(() => { - new Cline( - mockProvider, - mockApiConfig, - undefined, // customInstructions - false, // diffEnabled - undefined, // fuzzyMatchThreshold - undefined // task - ); - }).toThrow('Either historyItem or task/images must be provided'); - }); - }); + expect(cline.diffEnabled).toBe(true) + expect(cline.diffStrategy).toBeDefined() + expect(getDiffStrategySpy).toHaveBeenCalledWith("claude-3-5-sonnet-20241022", 0.9, false) - describe('getEnvironmentDetails', () => { - let originalDate: DateConstructor; - let mockDate: Date; + getDiffStrategySpy.mockRestore() + }) - beforeEach(() => { - originalDate = global.Date; - const fixedTime = new Date('2024-01-01T12:00:00Z'); - mockDate = new Date(fixedTime); - mockDate.getTimezoneOffset = jest.fn().mockReturnValue(420); // UTC-7 + it("should pass default threshold to diff strategy when not provided", () => { + const getDiffStrategySpy = jest.spyOn(require("../diff/DiffStrategy"), "getDiffStrategy") - class MockDate extends Date { - constructor() { - super(); - return mockDate; - } - static override now() { - return mockDate.getTime(); - } - } - - global.Date = MockDate as DateConstructor; + const cline = new Cline(mockProvider, mockApiConfig, "custom instructions", true, undefined, "test task") - // Create a proper mock of Intl.DateTimeFormat - const mockDateTimeFormat = { - resolvedOptions: () => ({ - timeZone: 'America/Los_Angeles' - }), - format: () => '1/1/2024, 5:00:00 AM' - }; + expect(cline.diffEnabled).toBe(true) + expect(cline.diffStrategy).toBeDefined() + expect(getDiffStrategySpy).toHaveBeenCalledWith("claude-3-5-sonnet-20241022", 1.0, false) - const MockDateTimeFormat = function(this: any) { - return mockDateTimeFormat; - } as any; + getDiffStrategySpy.mockRestore() + }) - MockDateTimeFormat.prototype = mockDateTimeFormat; - MockDateTimeFormat.supportedLocalesOf = jest.fn().mockReturnValue(['en-US']); + it("should require either task or historyItem", () => { + expect(() => { + new Cline( + mockProvider, + mockApiConfig, + undefined, // customInstructions + false, // diffEnabled + undefined, // fuzzyMatchThreshold + undefined, // task + ) + }).toThrow("Either historyItem or task/images must be provided") + }) + }) - global.Intl.DateTimeFormat = MockDateTimeFormat; - }); + describe("getEnvironmentDetails", () => { + let originalDate: DateConstructor + let mockDate: Date - afterEach(() => { - global.Date = originalDate; - }); + beforeEach(() => { + originalDate = global.Date + const fixedTime = new Date("2024-01-01T12:00:00Z") + mockDate = new Date(fixedTime) + mockDate.getTimezoneOffset = jest.fn().mockReturnValue(420) // UTC-7 - it('should include timezone information in environment details', async () => { - const cline = new Cline( - mockProvider, - mockApiConfig, - undefined, - false, - undefined, - 'test task' - ); + class MockDate extends Date { + constructor() { + super() + return mockDate + } + static override now() { + return mockDate.getTime() + } + } - const details = await cline['getEnvironmentDetails'](false); - - // Verify timezone information is present and formatted correctly - expect(details).toContain('America/Los_Angeles'); - expect(details).toMatch(/UTC-7:00/); // Fixed offset for America/Los_Angeles - expect(details).toContain('# Current Time'); - expect(details).toMatch(/1\/1\/2024.*5:00:00 AM.*\(America\/Los_Angeles, UTC-7:00\)/); // Full time string format - }); - - describe('API conversation handling', () => { - it('should clean conversation history before sending to API', async () => { - const cline = new Cline( - mockProvider, - mockApiConfig, - undefined, - false, - undefined, - 'test task' - ); - - // Mock the API's createMessage method to capture the conversation history - const createMessageSpy = jest.fn(); - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { type: 'text', text: '' }; - }, - async next() { - return { done: true, value: undefined }; - }, - async return() { - return { done: true, value: undefined }; - }, - async throw(e: any) { - throw e; - }, - async [Symbol.asyncDispose]() { - // Cleanup - } - } as AsyncGenerator; - - jest.spyOn(cline.api, 'createMessage').mockImplementation((...args) => { - createMessageSpy(...args); - return mockStream; - }); + global.Date = MockDate as DateConstructor - // Add a message with extra properties to the conversation history - const messageWithExtra = { - role: 'user' as const, - content: [{ type: 'text' as const, text: 'test message' }], - ts: Date.now(), - extraProp: 'should be removed' - }; - cline.apiConversationHistory = [messageWithExtra]; + // Create a proper mock of Intl.DateTimeFormat + const mockDateTimeFormat = { + resolvedOptions: () => ({ + timeZone: "America/Los_Angeles", + }), + format: () => "1/1/2024, 5:00:00 AM", + } - // Trigger an API request - await cline.recursivelyMakeClineRequests([ - { type: 'text', text: 'test request' } - ]); + const MockDateTimeFormat = function (this: any) { + return mockDateTimeFormat + } as any - // Get all calls to createMessage - const calls = createMessageSpy.mock.calls; - - // Find the call that includes our test message - const relevantCall = calls.find(call => - call[1]?.some((msg: any) => - msg.content?.[0]?.text === 'test message' - ) - ); + MockDateTimeFormat.prototype = mockDateTimeFormat + MockDateTimeFormat.supportedLocalesOf = jest.fn().mockReturnValue(["en-US"]) - // Verify the conversation history was cleaned in the relevant call - expect(relevantCall?.[1]).toEqual( - expect.arrayContaining([ - { - role: 'user', - content: [{ type: 'text', text: 'test message' }] - } - ]) - ); + global.Intl.DateTimeFormat = MockDateTimeFormat + }) - // Verify extra properties were removed - const passedMessage = relevantCall?.[1].find((msg: any) => - msg.content?.[0]?.text === 'test message' - ); - expect(passedMessage).not.toHaveProperty('ts'); - expect(passedMessage).not.toHaveProperty('extraProp'); - }); + afterEach(() => { + global.Date = originalDate + }) - it('should handle image blocks based on model capabilities', async () => { - // Create two configurations - one with image support, one without - const configWithImages = { - ...mockApiConfig, - apiModelId: 'claude-3-sonnet' - }; - const configWithoutImages = { - ...mockApiConfig, - apiModelId: 'gpt-3.5-turbo' - }; + it("should include timezone information in environment details", async () => { + const cline = new Cline(mockProvider, mockApiConfig, undefined, false, undefined, "test task") - // Create test conversation history with mixed content - const conversationHistory: (Anthropic.MessageParam & { ts?: number })[] = [ - { - role: 'user' as const, - content: [ - { - type: 'text' as const, - text: 'Here is an image' - } satisfies Anthropic.TextBlockParam, - { - type: 'image' as const, - source: { - type: 'base64' as const, - media_type: 'image/jpeg', - data: 'base64data' - } - } satisfies Anthropic.ImageBlockParam - ] - }, - { - role: 'assistant' as const, - content: [{ - type: 'text' as const, - text: 'I see the image' - } satisfies Anthropic.TextBlockParam] - } - ]; + const details = await cline["getEnvironmentDetails"](false) - // Test with model that supports images - const clineWithImages = new Cline( - mockProvider, - configWithImages, - undefined, - false, - undefined, - 'test task' - ); - // Mock the model info to indicate image support - jest.spyOn(clineWithImages.api, 'getModel').mockReturnValue({ - id: 'claude-3-sonnet', - info: { - supportsImages: true, - supportsPromptCache: true, - supportsComputerUse: true, - contextWindow: 200000, - maxTokens: 4096, - inputPrice: 0.25, - outputPrice: 0.75 - } as ModelInfo - }); - clineWithImages.apiConversationHistory = conversationHistory; + // Verify timezone information is present and formatted correctly + expect(details).toContain("America/Los_Angeles") + expect(details).toMatch(/UTC-7:00/) // Fixed offset for America/Los_Angeles + expect(details).toContain("# Current Time") + expect(details).toMatch(/1\/1\/2024.*5:00:00 AM.*\(America\/Los_Angeles, UTC-7:00\)/) // Full time string format + }) - // Test with model that doesn't support images - const clineWithoutImages = new Cline( - mockProvider, - configWithoutImages, - undefined, - false, - undefined, - 'test task' - ); - // Mock the model info to indicate no image support - jest.spyOn(clineWithoutImages.api, 'getModel').mockReturnValue({ - id: 'gpt-3.5-turbo', - info: { - supportsImages: false, - supportsPromptCache: false, - supportsComputerUse: false, - contextWindow: 16000, - maxTokens: 2048, - inputPrice: 0.1, - outputPrice: 0.2 - } as ModelInfo - }); - clineWithoutImages.apiConversationHistory = conversationHistory; + describe("API conversation handling", () => { + it("should clean conversation history before sending to API", async () => { + const cline = new Cline(mockProvider, mockApiConfig, undefined, false, undefined, "test task") - // Create message spy for both instances - const createMessageSpyWithImages = jest.fn(); - const createMessageSpyWithoutImages = jest.fn(); - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { type: 'text', text: '' }; - } - } as AsyncGenerator; + // Mock the API's createMessage method to capture the conversation history + const createMessageSpy = jest.fn() + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { type: "text", text: "" } + }, + async next() { + return { done: true, value: undefined } + }, + async return() { + return { done: true, value: undefined } + }, + async throw(e: any) { + throw e + }, + async [Symbol.asyncDispose]() { + // Cleanup + }, + } as AsyncGenerator - jest.spyOn(clineWithImages.api, 'createMessage').mockImplementation((...args) => { - createMessageSpyWithImages(...args); - return mockStream; - }); - jest.spyOn(clineWithoutImages.api, 'createMessage').mockImplementation((...args) => { - createMessageSpyWithoutImages(...args); - return mockStream; - }); + jest.spyOn(cline.api, "createMessage").mockImplementation((...args) => { + createMessageSpy(...args) + return mockStream + }) - // Trigger API requests for both instances - await clineWithImages.recursivelyMakeClineRequests([{ type: 'text', text: 'test' }]); - await clineWithoutImages.recursivelyMakeClineRequests([{ type: 'text', text: 'test' }]); + // Add a message with extra properties to the conversation history + const messageWithExtra = { + role: "user" as const, + content: [{ type: "text" as const, text: "test message" }], + ts: Date.now(), + extraProp: "should be removed", + } + cline.apiConversationHistory = [messageWithExtra] - // Verify model with image support preserves image blocks - const callsWithImages = createMessageSpyWithImages.mock.calls; - const historyWithImages = callsWithImages[0][1][0]; - expect(historyWithImages.content).toHaveLength(2); - expect(historyWithImages.content[0]).toEqual({ type: 'text', text: 'Here is an image' }); - expect(historyWithImages.content[1]).toHaveProperty('type', 'image'); + // Trigger an API request + await cline.recursivelyMakeClineRequests([{ type: "text", text: "test request" }]) - // Verify model without image support converts image blocks to text - const callsWithoutImages = createMessageSpyWithoutImages.mock.calls; - const historyWithoutImages = callsWithoutImages[0][1][0]; - expect(historyWithoutImages.content).toHaveLength(2); - expect(historyWithoutImages.content[0]).toEqual({ type: 'text', text: 'Here is an image' }); - expect(historyWithoutImages.content[1]).toEqual({ - type: 'text', - text: '[Referenced image in conversation]' - }); - }); - - it('should handle API retry with countdown', async () => { - const cline = new Cline( - mockProvider, - mockApiConfig, - undefined, - false, - undefined, - 'test task' - ); + // Get all calls to createMessage + const calls = createMessageSpy.mock.calls - // Mock delay to track countdown timing - const mockDelay = jest.fn().mockResolvedValue(undefined); - jest.spyOn(require('delay'), 'default').mockImplementation(mockDelay); + // Find the call that includes our test message + const relevantCall = calls.find((call) => + call[1]?.some((msg: any) => msg.content?.[0]?.text === "test message"), + ) - // Mock say to track messages - const saySpy = jest.spyOn(cline, 'say'); + // Verify the conversation history was cleaned in the relevant call + expect(relevantCall?.[1]).toEqual( + expect.arrayContaining([ + { + role: "user", + content: [{ type: "text", text: "test message" }], + }, + ]), + ) - // Create a stream that fails on first chunk - const mockError = new Error('API Error'); - const mockFailedStream = { - async *[Symbol.asyncIterator]() { - throw mockError; - }, - async next() { - throw mockError; - }, - async return() { - return { done: true, value: undefined }; - }, - async throw(e: any) { - throw e; - }, - async [Symbol.asyncDispose]() { - // Cleanup - } - } as AsyncGenerator; + // Verify extra properties were removed + const passedMessage = relevantCall?.[1].find((msg: any) => msg.content?.[0]?.text === "test message") + expect(passedMessage).not.toHaveProperty("ts") + expect(passedMessage).not.toHaveProperty("extraProp") + }) - // Create a successful stream for retry - const mockSuccessStream = { - async *[Symbol.asyncIterator]() { - yield { type: 'text', text: 'Success' }; - }, - async next() { - return { done: true, value: { type: 'text', text: 'Success' } }; - }, - async return() { - return { done: true, value: undefined }; - }, - async throw(e: any) { - throw e; - }, - async [Symbol.asyncDispose]() { - // Cleanup - } - } as AsyncGenerator; + it("should handle image blocks based on model capabilities", async () => { + // Create two configurations - one with image support, one without + const configWithImages = { + ...mockApiConfig, + apiModelId: "claude-3-sonnet", + } + const configWithoutImages = { + ...mockApiConfig, + apiModelId: "gpt-3.5-turbo", + } - // Mock createMessage to fail first then succeed - let firstAttempt = true; - jest.spyOn(cline.api, 'createMessage').mockImplementation(() => { - if (firstAttempt) { - firstAttempt = false; - return mockFailedStream; - } - return mockSuccessStream; - }); + // Create test conversation history with mixed content + const conversationHistory: (Anthropic.MessageParam & { ts?: number })[] = [ + { + role: "user" as const, + content: [ + { + type: "text" as const, + text: "Here is an image", + } satisfies Anthropic.TextBlockParam, + { + type: "image" as const, + source: { + type: "base64" as const, + media_type: "image/jpeg", + data: "base64data", + }, + } satisfies Anthropic.ImageBlockParam, + ], + }, + { + role: "assistant" as const, + content: [ + { + type: "text" as const, + text: "I see the image", + } satisfies Anthropic.TextBlockParam, + ], + }, + ] - // Set alwaysApproveResubmit and requestDelaySeconds - mockProvider.getState = jest.fn().mockResolvedValue({ - alwaysApproveResubmit: true, - requestDelaySeconds: 3 - }); + // Test with model that supports images + const clineWithImages = new Cline( + mockProvider, + configWithImages, + undefined, + false, + undefined, + "test task", + ) + // Mock the model info to indicate image support + jest.spyOn(clineWithImages.api, "getModel").mockReturnValue({ + id: "claude-3-sonnet", + info: { + supportsImages: true, + supportsPromptCache: true, + supportsComputerUse: true, + contextWindow: 200000, + maxTokens: 4096, + inputPrice: 0.25, + outputPrice: 0.75, + } as ModelInfo, + }) + clineWithImages.apiConversationHistory = conversationHistory - // Mock previous API request message - cline.clineMessages = [{ - ts: Date.now(), - type: 'say', - say: 'api_req_started', - text: JSON.stringify({ - tokensIn: 100, - tokensOut: 50, - cacheWrites: 0, - cacheReads: 0, - request: 'test request' - }) - }]; + // Test with model that doesn't support images + const clineWithoutImages = new Cline( + mockProvider, + configWithoutImages, + undefined, + false, + undefined, + "test task", + ) + // Mock the model info to indicate no image support + jest.spyOn(clineWithoutImages.api, "getModel").mockReturnValue({ + id: "gpt-3.5-turbo", + info: { + supportsImages: false, + supportsPromptCache: false, + supportsComputerUse: false, + contextWindow: 16000, + maxTokens: 2048, + inputPrice: 0.1, + outputPrice: 0.2, + } as ModelInfo, + }) + clineWithoutImages.apiConversationHistory = conversationHistory - // Trigger API request - const iterator = cline.attemptApiRequest(0); - await iterator.next(); + // Create message spy for both instances + const createMessageSpyWithImages = jest.fn() + const createMessageSpyWithoutImages = jest.fn() + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { type: "text", text: "" } + }, + } as AsyncGenerator - // Verify countdown messages - expect(saySpy).toHaveBeenCalledWith( - 'api_req_retry_delayed', - expect.stringContaining('Retrying in 3 seconds'), - undefined, - true - ); - expect(saySpy).toHaveBeenCalledWith( - 'api_req_retry_delayed', - expect.stringContaining('Retrying in 2 seconds'), - undefined, - true - ); - expect(saySpy).toHaveBeenCalledWith( - 'api_req_retry_delayed', - expect.stringContaining('Retrying in 1 seconds'), - undefined, - true - ); - expect(saySpy).toHaveBeenCalledWith( - 'api_req_retry_delayed', - expect.stringContaining('Retrying now'), - undefined, - false - ); + jest.spyOn(clineWithImages.api, "createMessage").mockImplementation((...args) => { + createMessageSpyWithImages(...args) + return mockStream + }) + jest.spyOn(clineWithoutImages.api, "createMessage").mockImplementation((...args) => { + createMessageSpyWithoutImages(...args) + return mockStream + }) - // Verify delay was called correctly - expect(mockDelay).toHaveBeenCalledTimes(3); - expect(mockDelay).toHaveBeenCalledWith(1000); + // Trigger API requests for both instances + await clineWithImages.recursivelyMakeClineRequests([{ type: "text", text: "test" }]) + await clineWithoutImages.recursivelyMakeClineRequests([{ type: "text", text: "test" }]) - // Verify error message content - const errorMessage = saySpy.mock.calls.find( - call => call[1]?.includes(mockError.message) - )?.[1]; - expect(errorMessage).toBe(`${mockError.message}\n\nRetrying in 3 seconds...`); - }); + // Verify model with image support preserves image blocks + const callsWithImages = createMessageSpyWithImages.mock.calls + const historyWithImages = callsWithImages[0][1][0] + expect(historyWithImages.content).toHaveLength(2) + expect(historyWithImages.content[0]).toEqual({ type: "text", text: "Here is an image" }) + expect(historyWithImages.content[1]).toHaveProperty("type", "image") - describe('loadContext', () => { - it('should process mentions in task and feedback tags', async () => { - const cline = new Cline( - mockProvider, - mockApiConfig, - undefined, - false, - undefined, - 'test task' - ); - - // Mock parseMentions to track calls - const mockParseMentions = jest.fn().mockImplementation(text => `processed: ${text}`); - jest.spyOn(require('../../core/mentions'), 'parseMentions').mockImplementation(mockParseMentions); - - const userContent = [ - { - type: 'text', - text: 'Regular text with @/some/path' - } as const, - { - type: 'text', - text: 'Text with @/some/path in task tags' - } as const, - { - type: 'tool_result', - tool_use_id: 'test-id', - content: [{ - type: 'text', - text: 'Check @/some/path' - }] - } as Anthropic.ToolResultBlockParam, - { - type: 'tool_result', - tool_use_id: 'test-id-2', - content: [{ - type: 'text', - text: 'Regular tool result with @/path' - }] - } as Anthropic.ToolResultBlockParam - ]; - - // Process the content - const [processedContent] = await cline['loadContext'](userContent); - - // Regular text should not be processed - expect((processedContent[0] as Anthropic.TextBlockParam).text) - .toBe('Regular text with @/some/path'); - - // Text within task tags should be processed - expect((processedContent[1] as Anthropic.TextBlockParam).text) - .toContain('processed:'); - expect(mockParseMentions).toHaveBeenCalledWith( - 'Text with @/some/path in task tags', - expect.any(String), - expect.any(Object) - ); - - // Feedback tag content should be processed - const toolResult1 = processedContent[2] as Anthropic.ToolResultBlockParam; - const content1 = Array.isArray(toolResult1.content) ? toolResult1.content[0] : toolResult1.content; - expect((content1 as Anthropic.TextBlockParam).text).toContain('processed:'); - expect(mockParseMentions).toHaveBeenCalledWith( - 'Check @/some/path', - expect.any(String), - expect.any(Object) - ); - - // Regular tool result should not be processed - const toolResult2 = processedContent[3] as Anthropic.ToolResultBlockParam; - const content2 = Array.isArray(toolResult2.content) ? toolResult2.content[0] : toolResult2.content; - expect((content2 as Anthropic.TextBlockParam).text) - .toBe('Regular tool result with @/path'); - }); - }); - }); - }); -}); + // Verify model without image support converts image blocks to text + const callsWithoutImages = createMessageSpyWithoutImages.mock.calls + const historyWithoutImages = callsWithoutImages[0][1][0] + expect(historyWithoutImages.content).toHaveLength(2) + expect(historyWithoutImages.content[0]).toEqual({ type: "text", text: "Here is an image" }) + expect(historyWithoutImages.content[1]).toEqual({ + type: "text", + text: "[Referenced image in conversation]", + }) + }) + + it("should handle API retry with countdown", async () => { + const cline = new Cline(mockProvider, mockApiConfig, undefined, false, undefined, "test task") + + // Mock delay to track countdown timing + const mockDelay = jest.fn().mockResolvedValue(undefined) + jest.spyOn(require("delay"), "default").mockImplementation(mockDelay) + + // Mock say to track messages + const saySpy = jest.spyOn(cline, "say") + + // Create a stream that fails on first chunk + const mockError = new Error("API Error") + const mockFailedStream = { + async *[Symbol.asyncIterator]() { + throw mockError + }, + async next() { + throw mockError + }, + async return() { + return { done: true, value: undefined } + }, + async throw(e: any) { + throw e + }, + async [Symbol.asyncDispose]() { + // Cleanup + }, + } as AsyncGenerator + + // Create a successful stream for retry + const mockSuccessStream = { + async *[Symbol.asyncIterator]() { + yield { type: "text", text: "Success" } + }, + async next() { + return { done: true, value: { type: "text", text: "Success" } } + }, + async return() { + return { done: true, value: undefined } + }, + async throw(e: any) { + throw e + }, + async [Symbol.asyncDispose]() { + // Cleanup + }, + } as AsyncGenerator + + // Mock createMessage to fail first then succeed + let firstAttempt = true + jest.spyOn(cline.api, "createMessage").mockImplementation(() => { + if (firstAttempt) { + firstAttempt = false + return mockFailedStream + } + return mockSuccessStream + }) + + // Set alwaysApproveResubmit and requestDelaySeconds + mockProvider.getState = jest.fn().mockResolvedValue({ + alwaysApproveResubmit: true, + requestDelaySeconds: 3, + }) + + // Mock previous API request message + cline.clineMessages = [ + { + ts: Date.now(), + type: "say", + say: "api_req_started", + text: JSON.stringify({ + tokensIn: 100, + tokensOut: 50, + cacheWrites: 0, + cacheReads: 0, + request: "test request", + }), + }, + ] + + // Trigger API request + const iterator = cline.attemptApiRequest(0) + await iterator.next() + + // Verify countdown messages + expect(saySpy).toHaveBeenCalledWith( + "api_req_retry_delayed", + expect.stringContaining("Retrying in 3 seconds"), + undefined, + true, + ) + expect(saySpy).toHaveBeenCalledWith( + "api_req_retry_delayed", + expect.stringContaining("Retrying in 2 seconds"), + undefined, + true, + ) + expect(saySpy).toHaveBeenCalledWith( + "api_req_retry_delayed", + expect.stringContaining("Retrying in 1 seconds"), + undefined, + true, + ) + expect(saySpy).toHaveBeenCalledWith( + "api_req_retry_delayed", + expect.stringContaining("Retrying now"), + undefined, + false, + ) + + // Verify delay was called correctly + expect(mockDelay).toHaveBeenCalledTimes(3) + expect(mockDelay).toHaveBeenCalledWith(1000) + + // Verify error message content + const errorMessage = saySpy.mock.calls.find((call) => call[1]?.includes(mockError.message))?.[1] + expect(errorMessage).toBe(`${mockError.message}\n\nRetrying in 3 seconds...`) + }) + + describe("loadContext", () => { + it("should process mentions in task and feedback tags", async () => { + const cline = new Cline(mockProvider, mockApiConfig, undefined, false, undefined, "test task") + + // Mock parseMentions to track calls + const mockParseMentions = jest.fn().mockImplementation((text) => `processed: ${text}`) + jest.spyOn(require("../../core/mentions"), "parseMentions").mockImplementation(mockParseMentions) + + const userContent = [ + { + type: "text", + text: "Regular text with @/some/path", + } as const, + { + type: "text", + text: "Text with @/some/path in task tags", + } as const, + { + type: "tool_result", + tool_use_id: "test-id", + content: [ + { + type: "text", + text: "Check @/some/path", + }, + ], + } as Anthropic.ToolResultBlockParam, + { + type: "tool_result", + tool_use_id: "test-id-2", + content: [ + { + type: "text", + text: "Regular tool result with @/path", + }, + ], + } as Anthropic.ToolResultBlockParam, + ] + + // Process the content + const [processedContent] = await cline["loadContext"](userContent) + + // Regular text should not be processed + expect((processedContent[0] as Anthropic.TextBlockParam).text).toBe("Regular text with @/some/path") + + // Text within task tags should be processed + expect((processedContent[1] as Anthropic.TextBlockParam).text).toContain("processed:") + expect(mockParseMentions).toHaveBeenCalledWith( + "Text with @/some/path in task tags", + expect.any(String), + expect.any(Object), + ) + + // Feedback tag content should be processed + const toolResult1 = processedContent[2] as Anthropic.ToolResultBlockParam + const content1 = Array.isArray(toolResult1.content) ? toolResult1.content[0] : toolResult1.content + expect((content1 as Anthropic.TextBlockParam).text).toContain("processed:") + expect(mockParseMentions).toHaveBeenCalledWith( + "Check @/some/path", + expect.any(String), + expect.any(Object), + ) + + // Regular tool result should not be processed + const toolResult2 = processedContent[3] as Anthropic.ToolResultBlockParam + const content2 = Array.isArray(toolResult2.content) ? toolResult2.content[0] : toolResult2.content + expect((content2 as Anthropic.TextBlockParam).text).toBe("Regular tool result with @/path") + }) + }) + }) + }) +}) diff --git a/src/core/__tests__/mode-validator.test.ts b/src/core/__tests__/mode-validator.test.ts index 4bcabb7..6842e52 100644 --- a/src/core/__tests__/mode-validator.test.ts +++ b/src/core/__tests__/mode-validator.test.ts @@ -1,52 +1,52 @@ -import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from '../../shared/modes'; -import { validateToolUse } from '../mode-validator'; +import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from "../../shared/modes" +import { validateToolUse } from "../mode-validator" -const asTestTool = (tool: string): TestToolName => tool as TestToolName; -const [codeMode, architectMode, askMode] = modes.map(mode => mode.slug); +const asTestTool = (tool: string): TestToolName => tool as TestToolName +const [codeMode, architectMode, askMode] = modes.map((mode) => mode.slug) -describe('mode-validator', () => { - describe('isToolAllowedForMode', () => { - describe('code mode', () => { - it('allows all code mode tools', () => { - const mode = getModeConfig(codeMode); - mode.tools.forEach(([tool]) => { - expect(isToolAllowedForMode(tool, codeMode)).toBe(true) - }) - }) +describe("mode-validator", () => { + describe("isToolAllowedForMode", () => { + describe("code mode", () => { + it("allows all code mode tools", () => { + const mode = getModeConfig(codeMode) + mode.tools.forEach(([tool]) => { + expect(isToolAllowedForMode(tool, codeMode)).toBe(true) + }) + }) - it('disallows unknown tools', () => { - expect(isToolAllowedForMode(asTestTool('unknown_tool'), codeMode)).toBe(false) - }) - }) + it("disallows unknown tools", () => { + expect(isToolAllowedForMode(asTestTool("unknown_tool"), codeMode)).toBe(false) + }) + }) - describe('architect mode', () => { - it('allows configured tools', () => { - const mode = getModeConfig(architectMode); - mode.tools.forEach(([tool]) => { - expect(isToolAllowedForMode(tool, architectMode)).toBe(true) - }) - }) - }) + describe("architect mode", () => { + it("allows configured tools", () => { + const mode = getModeConfig(architectMode) + mode.tools.forEach(([tool]) => { + expect(isToolAllowedForMode(tool, architectMode)).toBe(true) + }) + }) + }) - describe('ask mode', () => { - it('allows configured tools', () => { - const mode = getModeConfig(askMode); - mode.tools.forEach(([tool]) => { - expect(isToolAllowedForMode(tool, askMode)).toBe(true) - }) - }) - }) - }) + describe("ask mode", () => { + it("allows configured tools", () => { + const mode = getModeConfig(askMode) + mode.tools.forEach(([tool]) => { + expect(isToolAllowedForMode(tool, askMode)).toBe(true) + }) + }) + }) + }) - describe('validateToolUse', () => { - it('throws error for disallowed tools in architect mode', () => { - expect(() => validateToolUse('unknown_tool', 'architect')).toThrow( - 'Tool "unknown_tool" is not allowed in architect mode.' - ) - }) + describe("validateToolUse", () => { + it("throws error for disallowed tools in architect mode", () => { + expect(() => validateToolUse("unknown_tool", "architect")).toThrow( + 'Tool "unknown_tool" is not allowed in architect mode.', + ) + }) - it('does not throw for allowed tools in architect mode', () => { - expect(() => validateToolUse('read_file', 'architect')).not.toThrow() - }) - }) -}) \ No newline at end of file + it("does not throw for allowed tools in architect mode", () => { + expect(() => validateToolUse("read_file", "architect")).not.toThrow() + }) + }) +}) diff --git a/src/core/config/ConfigManager.ts b/src/core/config/ConfigManager.ts index 8b7651e..ef3d4d3 100644 --- a/src/core/config/ConfigManager.ts +++ b/src/core/config/ConfigManager.ts @@ -1,221 +1,221 @@ -import { ExtensionContext } from 'vscode' -import { ApiConfiguration } from '../../shared/api' -import { Mode } from '../prompts/types' -import { ApiConfigMeta } from '../../shared/ExtensionMessage' +import { ExtensionContext } from "vscode" +import { ApiConfiguration } from "../../shared/api" +import { Mode } from "../prompts/types" +import { ApiConfigMeta } from "../../shared/ExtensionMessage" export interface ApiConfigData { - currentApiConfigName: string - apiConfigs: { - [key: string]: ApiConfiguration - } - modeApiConfigs?: Partial> + currentApiConfigName: string + apiConfigs: { + [key: string]: ApiConfiguration + } + modeApiConfigs?: Partial> } export class ConfigManager { - private readonly defaultConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - default: { - id: this.generateId() - } - } - } + private readonly defaultConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + default: { + id: this.generateId(), + }, + }, + } - private readonly SCOPE_PREFIX = "roo_cline_config_" - private readonly context: ExtensionContext + private readonly SCOPE_PREFIX = "roo_cline_config_" + private readonly context: ExtensionContext - constructor(context: ExtensionContext) { - this.context = context - this.initConfig().catch(console.error) - } + constructor(context: ExtensionContext) { + this.context = context + this.initConfig().catch(console.error) + } - private generateId(): string { - return Math.random().toString(36).substring(2, 15) - } + private generateId(): string { + return Math.random().toString(36).substring(2, 15) + } - /** - * Initialize config if it doesn't exist - */ - async initConfig(): Promise { - try { - const config = await this.readConfig() - if (!config) { - await this.writeConfig(this.defaultConfig) - return - } + /** + * Initialize config if it doesn't exist + */ + async initConfig(): Promise { + try { + const config = await this.readConfig() + if (!config) { + await this.writeConfig(this.defaultConfig) + return + } - // Migrate: ensure all configs have IDs - let needsMigration = false - for (const [name, apiConfig] of Object.entries(config.apiConfigs)) { - if (!apiConfig.id) { - apiConfig.id = this.generateId() - needsMigration = true - } - } + // Migrate: ensure all configs have IDs + let needsMigration = false + for (const [name, apiConfig] of Object.entries(config.apiConfigs)) { + if (!apiConfig.id) { + apiConfig.id = this.generateId() + needsMigration = true + } + } - if (needsMigration) { - await this.writeConfig(config) - } - } catch (error) { - throw new Error(`Failed to initialize config: ${error}`) - } - } + if (needsMigration) { + await this.writeConfig(config) + } + } catch (error) { + throw new Error(`Failed to initialize config: ${error}`) + } + } - /** - * List all available configs with metadata - */ - async ListConfig(): Promise { - try { - const config = await this.readConfig() - return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({ - name, - id: apiConfig.id || '', - apiProvider: apiConfig.apiProvider, - })) - } catch (error) { - throw new Error(`Failed to list configs: ${error}`) - } - } + /** + * List all available configs with metadata + */ + async ListConfig(): Promise { + try { + const config = await this.readConfig() + return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({ + name, + id: apiConfig.id || "", + apiProvider: apiConfig.apiProvider, + })) + } catch (error) { + throw new Error(`Failed to list configs: ${error}`) + } + } - /** - * Save a config with the given name - */ - async SaveConfig(name: string, config: ApiConfiguration): Promise { - try { - const currentConfig = await this.readConfig() - const existingConfig = currentConfig.apiConfigs[name] - currentConfig.apiConfigs[name] = { - ...config, - id: existingConfig?.id || this.generateId() - } - await this.writeConfig(currentConfig) - } catch (error) { - throw new Error(`Failed to save config: ${error}`) - } - } + /** + * Save a config with the given name + */ + async SaveConfig(name: string, config: ApiConfiguration): Promise { + try { + const currentConfig = await this.readConfig() + const existingConfig = currentConfig.apiConfigs[name] + currentConfig.apiConfigs[name] = { + ...config, + id: existingConfig?.id || this.generateId(), + } + await this.writeConfig(currentConfig) + } catch (error) { + throw new Error(`Failed to save config: ${error}`) + } + } - /** - * Load a config by name - */ - async LoadConfig(name: string): Promise { - try { - const config = await this.readConfig() - const apiConfig = config.apiConfigs[name] - - if (!apiConfig) { - throw new Error(`Config '${name}' not found`) - } - - config.currentApiConfigName = name; - await this.writeConfig(config) - - return apiConfig - } catch (error) { - throw new Error(`Failed to load config: ${error}`) - } - } + /** + * Load a config by name + */ + async LoadConfig(name: string): Promise { + try { + const config = await this.readConfig() + const apiConfig = config.apiConfigs[name] - /** - * Delete a config by name - */ - async DeleteConfig(name: string): Promise { - try { - const currentConfig = await this.readConfig() - if (!currentConfig.apiConfigs[name]) { - throw new Error(`Config '${name}' not found`) - } + if (!apiConfig) { + throw new Error(`Config '${name}' not found`) + } - // Don't allow deleting the default config - if (Object.keys(currentConfig.apiConfigs).length === 1) { - throw new Error(`Cannot delete the last remaining configuration.`) - } + config.currentApiConfigName = name + await this.writeConfig(config) - delete currentConfig.apiConfigs[name] - await this.writeConfig(currentConfig) - } catch (error) { - throw new Error(`Failed to delete config: ${error}`) - } - } + return apiConfig + } catch (error) { + throw new Error(`Failed to load config: ${error}`) + } + } - /** - * Set the current active API configuration - */ - async SetCurrentConfig(name: string): Promise { - try { - const currentConfig = await this.readConfig() - if (!currentConfig.apiConfigs[name]) { - throw new Error(`Config '${name}' not found`) - } + /** + * Delete a config by name + */ + async DeleteConfig(name: string): Promise { + try { + const currentConfig = await this.readConfig() + if (!currentConfig.apiConfigs[name]) { + throw new Error(`Config '${name}' not found`) + } - currentConfig.currentApiConfigName = name - await this.writeConfig(currentConfig) - } catch (error) { - throw new Error(`Failed to set current config: ${error}`) - } - } + // Don't allow deleting the default config + if (Object.keys(currentConfig.apiConfigs).length === 1) { + throw new Error(`Cannot delete the last remaining configuration.`) + } - /** - * Check if a config exists by name - */ - async HasConfig(name: string): Promise { - try { - const config = await this.readConfig() - return name in config.apiConfigs - } catch (error) { - throw new Error(`Failed to check config existence: ${error}`) - } - } + delete currentConfig.apiConfigs[name] + await this.writeConfig(currentConfig) + } catch (error) { + throw new Error(`Failed to delete config: ${error}`) + } + } - /** - * Set the API config for a specific mode - */ - async SetModeConfig(mode: Mode, configId: string): Promise { - try { - const currentConfig = await this.readConfig() - if (!currentConfig.modeApiConfigs) { - currentConfig.modeApiConfigs = {} - } - currentConfig.modeApiConfigs[mode] = configId - await this.writeConfig(currentConfig) - } catch (error) { - throw new Error(`Failed to set mode config: ${error}`) - } - } + /** + * Set the current active API configuration + */ + async SetCurrentConfig(name: string): Promise { + try { + const currentConfig = await this.readConfig() + if (!currentConfig.apiConfigs[name]) { + throw new Error(`Config '${name}' not found`) + } - /** - * Get the API config ID for a specific mode - */ - async GetModeConfigId(mode: Mode): Promise { - try { - const config = await this.readConfig() - return config.modeApiConfigs?.[mode] - } catch (error) { - throw new Error(`Failed to get mode config: ${error}`) - } - } + currentConfig.currentApiConfigName = name + await this.writeConfig(currentConfig) + } catch (error) { + throw new Error(`Failed to set current config: ${error}`) + } + } - private async readConfig(): Promise { - try { - const configKey = `${this.SCOPE_PREFIX}api_config` - const content = await this.context.secrets.get(configKey) - - if (!content) { - return this.defaultConfig - } + /** + * Check if a config exists by name + */ + async HasConfig(name: string): Promise { + try { + const config = await this.readConfig() + return name in config.apiConfigs + } catch (error) { + throw new Error(`Failed to check config existence: ${error}`) + } + } - return JSON.parse(content) - } catch (error) { - throw new Error(`Failed to read config from secrets: ${error}`) - } - } + /** + * Set the API config for a specific mode + */ + async SetModeConfig(mode: Mode, configId: string): Promise { + try { + const currentConfig = await this.readConfig() + if (!currentConfig.modeApiConfigs) { + currentConfig.modeApiConfigs = {} + } + currentConfig.modeApiConfigs[mode] = configId + await this.writeConfig(currentConfig) + } catch (error) { + throw new Error(`Failed to set mode config: ${error}`) + } + } - private async writeConfig(config: ApiConfigData): Promise { - try { - const configKey = `${this.SCOPE_PREFIX}api_config` - const content = JSON.stringify(config, null, 2) - await this.context.secrets.store(configKey, content) - } catch (error) { - throw new Error(`Failed to write config to secrets: ${error}`) - } - } -} \ No newline at end of file + /** + * Get the API config ID for a specific mode + */ + async GetModeConfigId(mode: Mode): Promise { + try { + const config = await this.readConfig() + return config.modeApiConfigs?.[mode] + } catch (error) { + throw new Error(`Failed to get mode config: ${error}`) + } + } + + private async readConfig(): Promise { + try { + const configKey = `${this.SCOPE_PREFIX}api_config` + const content = await this.context.secrets.get(configKey) + + if (!content) { + return this.defaultConfig + } + + return JSON.parse(content) + } catch (error) { + throw new Error(`Failed to read config from secrets: ${error}`) + } + } + + private async writeConfig(config: ApiConfigData): Promise { + try { + const configKey = `${this.SCOPE_PREFIX}api_config` + const content = JSON.stringify(config, null, 2) + await this.context.secrets.store(configKey, content) + } catch (error) { + throw new Error(`Failed to write config to secrets: ${error}`) + } + } +} diff --git a/src/core/config/__tests__/ConfigManager.test.ts b/src/core/config/__tests__/ConfigManager.test.ts index 09a4964..59f36e6 100644 --- a/src/core/config/__tests__/ConfigManager.test.ts +++ b/src/core/config/__tests__/ConfigManager.test.ts @@ -1,452 +1,470 @@ -import { ExtensionContext } from 'vscode' -import { ConfigManager, ApiConfigData } from '../ConfigManager' -import { ApiConfiguration } from '../../../shared/api' +import { ExtensionContext } from "vscode" +import { ConfigManager, ApiConfigData } from "../ConfigManager" +import { ApiConfiguration } from "../../../shared/api" // Mock VSCode ExtensionContext const mockSecrets = { - get: jest.fn(), - store: jest.fn(), - delete: jest.fn() + get: jest.fn(), + store: jest.fn(), + delete: jest.fn(), } const mockContext = { - secrets: mockSecrets + secrets: mockSecrets, } as unknown as ExtensionContext -describe('ConfigManager', () => { - let configManager: ConfigManager - - beforeEach(() => { - jest.clearAllMocks() - configManager = new ConfigManager(mockContext) - }) - - describe('initConfig', () => { - it('should not write to storage when secrets.get returns null', async () => { - // Mock readConfig to return null - mockSecrets.get.mockResolvedValueOnce(null) - - await configManager.initConfig() - - // Should not write to storage because readConfig returns defaultConfig - expect(mockSecrets.store).not.toHaveBeenCalled() - }) - - it('should not initialize config if it exists', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - default: { - config: {}, - id: 'default' - } - } - })) - - await configManager.initConfig() - - expect(mockSecrets.store).not.toHaveBeenCalled() - }) - - it('should generate IDs for configs that lack them', async () => { - // Mock a config with missing IDs - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - default: { - config: {} - }, - test: { - apiProvider: 'anthropic' - } - } - })) - - await configManager.initConfig() - - // Should have written the config with new IDs - expect(mockSecrets.store).toHaveBeenCalled() - const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) - expect(storedConfig.apiConfigs.default.id).toBeTruthy() - expect(storedConfig.apiConfigs.test.id).toBeTruthy() - }) - - it('should throw error if secrets storage fails', async () => { - mockSecrets.get.mockRejectedValue(new Error('Storage failed')) - - await expect(configManager.initConfig()).rejects.toThrow( - 'Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed' - ) - }) - }) - - describe('ListConfig', () => { - it('should list all available configs', async () => { - const existingConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - default: { - id: 'default' - }, - test: { - apiProvider: 'anthropic', - id: 'test-id' - } - }, - modeApiConfigs: { - code: 'default', - architect: 'default', - ask: 'default' - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) - - const configs = await configManager.ListConfig() - expect(configs).toEqual([ - { name: 'default', id: 'default', apiProvider: undefined }, - { name: 'test', id: 'test-id', apiProvider: 'anthropic' } - ]) - }) - - it('should handle empty config file', async () => { - const emptyConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: {}, - modeApiConfigs: { - code: 'default', - architect: 'default', - ask: 'default' - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig)) - - const configs = await configManager.ListConfig() - expect(configs).toEqual([]) - }) - - it('should throw error if reading from secrets fails', async () => { - mockSecrets.get.mockRejectedValue(new Error('Read failed')) - - await expect(configManager.ListConfig()).rejects.toThrow( - 'Failed to list configs: Error: Failed to read config from secrets: Error: Read failed' - ) - }) - }) - - describe('SaveConfig', () => { - it('should save new config', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - default: {} - }, - modeApiConfigs: { - code: 'default', - architect: 'default', - ask: 'default' - } - })) - - const newConfig: ApiConfiguration = { - apiProvider: 'anthropic', - apiKey: 'test-key' - } - - await configManager.SaveConfig('test', newConfig) - - // Get the actual stored config to check the generated ID - const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) - const testConfigId = storedConfig.apiConfigs.test.id - - const expectedConfig = { - currentApiConfigName: 'default', - apiConfigs: { - default: {}, - test: { - ...newConfig, - id: testConfigId - } - }, - modeApiConfigs: { - code: 'default', - architect: 'default', - ask: 'default' - } - } - - expect(mockSecrets.store).toHaveBeenCalledWith( - 'roo_cline_config_api_config', - JSON.stringify(expectedConfig, null, 2) - ) - }) - - it('should update existing config', async () => { - const existingConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - test: { - apiProvider: 'anthropic', - apiKey: 'old-key', - id: 'test-id' - } - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) - - const updatedConfig: ApiConfiguration = { - apiProvider: 'anthropic', - apiKey: 'new-key' - } - - await configManager.SaveConfig('test', updatedConfig) - - const expectedConfig = { - currentApiConfigName: 'default', - apiConfigs: { - test: { - apiProvider: 'anthropic', - apiKey: 'new-key', - id: 'test-id' - } - } - } - - expect(mockSecrets.store).toHaveBeenCalledWith( - 'roo_cline_config_api_config', - JSON.stringify(expectedConfig, null, 2) - ) - }) - - it('should throw error if secrets storage fails', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { default: {} } - })) - mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed')) - - await expect(configManager.SaveConfig('test', {})).rejects.toThrow( - 'Failed to save config: Error: Failed to write config to secrets: Error: Storage failed' - ) - }) - }) - - describe('DeleteConfig', () => { - it('should delete existing config', async () => { - const existingConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - default: { - id: 'default' - }, - test: { - apiProvider: 'anthropic', - id: 'test-id' - } - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) - - await configManager.DeleteConfig('test') - - // Get the stored config to check the ID - const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) - expect(storedConfig.currentApiConfigName).toBe('default') - expect(Object.keys(storedConfig.apiConfigs)).toEqual(['default']) - expect(storedConfig.apiConfigs.default.id).toBeTruthy() - }) - - it('should throw error when trying to delete non-existent config', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { default: {} } - })) - - await expect(configManager.DeleteConfig('nonexistent')).rejects.toThrow( - "Config 'nonexistent' not found" - ) - }) - - it('should throw error when trying to delete last remaining config', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - default: { - id: 'default' - } - } - })) - - await expect(configManager.DeleteConfig('default')).rejects.toThrow( - 'Cannot delete the last remaining configuration.' - ) - }) - }) - - describe('LoadConfig', () => { - it('should load config and update current config name', async () => { - const existingConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - test: { - apiProvider: 'anthropic', - apiKey: 'test-key', - id: 'test-id' - } - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) - - const config = await configManager.LoadConfig('test') - - expect(config).toEqual({ - apiProvider: 'anthropic', - apiKey: 'test-key', - id: 'test-id' - }) - - // Get the stored config to check the structure - const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) - expect(storedConfig.currentApiConfigName).toBe('test') - expect(storedConfig.apiConfigs.test).toEqual({ - apiProvider: 'anthropic', - apiKey: 'test-key', - id: 'test-id' - }) - }) - - it('should throw error when config does not exist', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - default: { - config: {}, - id: 'default' - } - } - })) - - await expect(configManager.LoadConfig('nonexistent')).rejects.toThrow( - "Config 'nonexistent' not found" - ) - }) - - it('should throw error if secrets storage fails', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - test: { - config: { - apiProvider: 'anthropic' - }, - id: 'test-id' - } - } - })) - mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed')) - - await expect(configManager.LoadConfig('test')).rejects.toThrow( - 'Failed to load config: Error: Failed to write config to secrets: Error: Storage failed' - ) - }) - }) - - describe('SetCurrentConfig', () => { - it('should set current config', async () => { - const existingConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - default: { - id: 'default' - }, - test: { - apiProvider: 'anthropic', - id: 'test-id' - } - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) - - await configManager.SetCurrentConfig('test') - - // Get the stored config to check the structure - const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) - expect(storedConfig.currentApiConfigName).toBe('test') - expect(storedConfig.apiConfigs.default.id).toBe('default') - expect(storedConfig.apiConfigs.test).toEqual({ - apiProvider: 'anthropic', - id: 'test-id' - }) - }) - - it('should throw error when config does not exist', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { default: {} } - })) - - await expect(configManager.SetCurrentConfig('nonexistent')).rejects.toThrow( - "Config 'nonexistent' not found" - ) - }) - - it('should throw error if secrets storage fails', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { - test: { apiProvider: 'anthropic' } - } - })) - mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed')) - - await expect(configManager.SetCurrentConfig('test')).rejects.toThrow( - 'Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed' - ) - }) - }) - - describe('HasConfig', () => { - it('should return true for existing config', async () => { - const existingConfig: ApiConfigData = { - currentApiConfigName: 'default', - apiConfigs: { - default: { - id: 'default' - }, - test: { - apiProvider: 'anthropic', - id: 'test-id' - } - } - } - - mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) - - const hasConfig = await configManager.HasConfig('test') - expect(hasConfig).toBe(true) - }) - - it('should return false for non-existent config', async () => { - mockSecrets.get.mockResolvedValue(JSON.stringify({ - currentApiConfigName: 'default', - apiConfigs: { default: {} } - })) - - const hasConfig = await configManager.HasConfig('nonexistent') - expect(hasConfig).toBe(false) - }) - - it('should throw error if secrets storage fails', async () => { - mockSecrets.get.mockRejectedValue(new Error('Storage failed')) - - await expect(configManager.HasConfig('test')).rejects.toThrow( - 'Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed' - ) - }) - }) -}) \ No newline at end of file +describe("ConfigManager", () => { + let configManager: ConfigManager + + beforeEach(() => { + jest.clearAllMocks() + configManager = new ConfigManager(mockContext) + }) + + describe("initConfig", () => { + it("should not write to storage when secrets.get returns null", async () => { + // Mock readConfig to return null + mockSecrets.get.mockResolvedValueOnce(null) + + await configManager.initConfig() + + // Should not write to storage because readConfig returns defaultConfig + expect(mockSecrets.store).not.toHaveBeenCalled() + }) + + it("should not initialize config if it exists", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + default: { + config: {}, + id: "default", + }, + }, + }), + ) + + await configManager.initConfig() + + expect(mockSecrets.store).not.toHaveBeenCalled() + }) + + it("should generate IDs for configs that lack them", async () => { + // Mock a config with missing IDs + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + default: { + config: {}, + }, + test: { + apiProvider: "anthropic", + }, + }, + }), + ) + + await configManager.initConfig() + + // Should have written the config with new IDs + expect(mockSecrets.store).toHaveBeenCalled() + const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) + expect(storedConfig.apiConfigs.default.id).toBeTruthy() + expect(storedConfig.apiConfigs.test.id).toBeTruthy() + }) + + it("should throw error if secrets storage fails", async () => { + mockSecrets.get.mockRejectedValue(new Error("Storage failed")) + + await expect(configManager.initConfig()).rejects.toThrow( + "Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed", + ) + }) + }) + + describe("ListConfig", () => { + it("should list all available configs", async () => { + const existingConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + default: { + id: "default", + }, + test: { + apiProvider: "anthropic", + id: "test-id", + }, + }, + modeApiConfigs: { + code: "default", + architect: "default", + ask: "default", + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) + + const configs = await configManager.ListConfig() + expect(configs).toEqual([ + { name: "default", id: "default", apiProvider: undefined }, + { name: "test", id: "test-id", apiProvider: "anthropic" }, + ]) + }) + + it("should handle empty config file", async () => { + const emptyConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: {}, + modeApiConfigs: { + code: "default", + architect: "default", + ask: "default", + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig)) + + const configs = await configManager.ListConfig() + expect(configs).toEqual([]) + }) + + it("should throw error if reading from secrets fails", async () => { + mockSecrets.get.mockRejectedValue(new Error("Read failed")) + + await expect(configManager.ListConfig()).rejects.toThrow( + "Failed to list configs: Error: Failed to read config from secrets: Error: Read failed", + ) + }) + }) + + describe("SaveConfig", () => { + it("should save new config", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + default: {}, + }, + modeApiConfigs: { + code: "default", + architect: "default", + ask: "default", + }, + }), + ) + + const newConfig: ApiConfiguration = { + apiProvider: "anthropic", + apiKey: "test-key", + } + + await configManager.SaveConfig("test", newConfig) + + // Get the actual stored config to check the generated ID + const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) + const testConfigId = storedConfig.apiConfigs.test.id + + const expectedConfig = { + currentApiConfigName: "default", + apiConfigs: { + default: {}, + test: { + ...newConfig, + id: testConfigId, + }, + }, + modeApiConfigs: { + code: "default", + architect: "default", + ask: "default", + }, + } + + expect(mockSecrets.store).toHaveBeenCalledWith( + "roo_cline_config_api_config", + JSON.stringify(expectedConfig, null, 2), + ) + }) + + it("should update existing config", async () => { + const existingConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + test: { + apiProvider: "anthropic", + apiKey: "old-key", + id: "test-id", + }, + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) + + const updatedConfig: ApiConfiguration = { + apiProvider: "anthropic", + apiKey: "new-key", + } + + await configManager.SaveConfig("test", updatedConfig) + + const expectedConfig = { + currentApiConfigName: "default", + apiConfigs: { + test: { + apiProvider: "anthropic", + apiKey: "new-key", + id: "test-id", + }, + }, + } + + expect(mockSecrets.store).toHaveBeenCalledWith( + "roo_cline_config_api_config", + JSON.stringify(expectedConfig, null, 2), + ) + }) + + it("should throw error if secrets storage fails", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { default: {} }, + }), + ) + mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed")) + + await expect(configManager.SaveConfig("test", {})).rejects.toThrow( + "Failed to save config: Error: Failed to write config to secrets: Error: Storage failed", + ) + }) + }) + + describe("DeleteConfig", () => { + it("should delete existing config", async () => { + const existingConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + default: { + id: "default", + }, + test: { + apiProvider: "anthropic", + id: "test-id", + }, + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) + + await configManager.DeleteConfig("test") + + // Get the stored config to check the ID + const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) + expect(storedConfig.currentApiConfigName).toBe("default") + expect(Object.keys(storedConfig.apiConfigs)).toEqual(["default"]) + expect(storedConfig.apiConfigs.default.id).toBeTruthy() + }) + + it("should throw error when trying to delete non-existent config", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { default: {} }, + }), + ) + + await expect(configManager.DeleteConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found") + }) + + it("should throw error when trying to delete last remaining config", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + default: { + id: "default", + }, + }, + }), + ) + + await expect(configManager.DeleteConfig("default")).rejects.toThrow( + "Cannot delete the last remaining configuration.", + ) + }) + }) + + describe("LoadConfig", () => { + it("should load config and update current config name", async () => { + const existingConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + test: { + apiProvider: "anthropic", + apiKey: "test-key", + id: "test-id", + }, + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) + + const config = await configManager.LoadConfig("test") + + expect(config).toEqual({ + apiProvider: "anthropic", + apiKey: "test-key", + id: "test-id", + }) + + // Get the stored config to check the structure + const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) + expect(storedConfig.currentApiConfigName).toBe("test") + expect(storedConfig.apiConfigs.test).toEqual({ + apiProvider: "anthropic", + apiKey: "test-key", + id: "test-id", + }) + }) + + it("should throw error when config does not exist", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + default: { + config: {}, + id: "default", + }, + }, + }), + ) + + await expect(configManager.LoadConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found") + }) + + it("should throw error if secrets storage fails", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + test: { + config: { + apiProvider: "anthropic", + }, + id: "test-id", + }, + }, + }), + ) + mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed")) + + await expect(configManager.LoadConfig("test")).rejects.toThrow( + "Failed to load config: Error: Failed to write config to secrets: Error: Storage failed", + ) + }) + }) + + describe("SetCurrentConfig", () => { + it("should set current config", async () => { + const existingConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + default: { + id: "default", + }, + test: { + apiProvider: "anthropic", + id: "test-id", + }, + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) + + await configManager.SetCurrentConfig("test") + + // Get the stored config to check the structure + const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) + expect(storedConfig.currentApiConfigName).toBe("test") + expect(storedConfig.apiConfigs.default.id).toBe("default") + expect(storedConfig.apiConfigs.test).toEqual({ + apiProvider: "anthropic", + id: "test-id", + }) + }) + + it("should throw error when config does not exist", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { default: {} }, + }), + ) + + await expect(configManager.SetCurrentConfig("nonexistent")).rejects.toThrow( + "Config 'nonexistent' not found", + ) + }) + + it("should throw error if secrets storage fails", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + test: { apiProvider: "anthropic" }, + }, + }), + ) + mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed")) + + await expect(configManager.SetCurrentConfig("test")).rejects.toThrow( + "Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed", + ) + }) + }) + + describe("HasConfig", () => { + it("should return true for existing config", async () => { + const existingConfig: ApiConfigData = { + currentApiConfigName: "default", + apiConfigs: { + default: { + id: "default", + }, + test: { + apiProvider: "anthropic", + id: "test-id", + }, + }, + } + + mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) + + const hasConfig = await configManager.HasConfig("test") + expect(hasConfig).toBe(true) + }) + + it("should return false for non-existent config", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { default: {} }, + }), + ) + + const hasConfig = await configManager.HasConfig("nonexistent") + expect(hasConfig).toBe(false) + }) + + it("should throw error if secrets storage fails", async () => { + mockSecrets.get.mockRejectedValue(new Error("Storage failed")) + + await expect(configManager.HasConfig("test")).rejects.toThrow( + "Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed", + ) + }) + }) +}) diff --git a/src/core/diff/DiffStrategy.ts b/src/core/diff/DiffStrategy.ts index ac3a0c4..de52498 100644 --- a/src/core/diff/DiffStrategy.ts +++ b/src/core/diff/DiffStrategy.ts @@ -1,17 +1,21 @@ -import type { DiffStrategy } from './types' -import { UnifiedDiffStrategy } from './strategies/unified' -import { SearchReplaceDiffStrategy } from './strategies/search-replace' -import { NewUnifiedDiffStrategy } from './strategies/new-unified' +import type { DiffStrategy } from "./types" +import { UnifiedDiffStrategy } from "./strategies/unified" +import { SearchReplaceDiffStrategy } from "./strategies/search-replace" +import { NewUnifiedDiffStrategy } from "./strategies/new-unified" /** * Get the appropriate diff strategy for the given model * @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus') * @returns The appropriate diff strategy for the model */ -export function getDiffStrategy(model: string, fuzzyMatchThreshold?: number, experimentalDiffStrategy: boolean = false): DiffStrategy { - if (experimentalDiffStrategy) { - return new NewUnifiedDiffStrategy(fuzzyMatchThreshold) - } - return new SearchReplaceDiffStrategy(fuzzyMatchThreshold) +export function getDiffStrategy( + model: string, + fuzzyMatchThreshold?: number, + experimentalDiffStrategy: boolean = false, +): DiffStrategy { + if (experimentalDiffStrategy) { + return new NewUnifiedDiffStrategy(fuzzyMatchThreshold) + } + return new SearchReplaceDiffStrategy(fuzzyMatchThreshold) } export type { DiffStrategy } diff --git a/src/core/diff/strategies/__tests__/new-unified.test.ts b/src/core/diff/strategies/__tests__/new-unified.test.ts index 3504e18..da30173 100644 --- a/src/core/diff/strategies/__tests__/new-unified.test.ts +++ b/src/core/diff/strategies/__tests__/new-unified.test.ts @@ -1,74 +1,73 @@ -import { NewUnifiedDiffStrategy } from '../new-unified'; +import { NewUnifiedDiffStrategy } from "../new-unified" -describe('main', () => { +describe("main", () => { + let strategy: NewUnifiedDiffStrategy - let strategy: NewUnifiedDiffStrategy + beforeEach(() => { + strategy = new NewUnifiedDiffStrategy(0.97) + }) - beforeEach(() => { - strategy = new NewUnifiedDiffStrategy(0.97) - }) + describe("constructor", () => { + it("should use default confidence threshold when not provided", () => { + const defaultStrategy = new NewUnifiedDiffStrategy() + expect(defaultStrategy["confidenceThreshold"]).toBe(1) + }) - describe('constructor', () => { - it('should use default confidence threshold when not provided', () => { - const defaultStrategy = new NewUnifiedDiffStrategy() - expect(defaultStrategy['confidenceThreshold']).toBe(1) - }) + it("should use provided confidence threshold", () => { + const customStrategy = new NewUnifiedDiffStrategy(0.85) + expect(customStrategy["confidenceThreshold"]).toBe(0.85) + }) - it('should use provided confidence threshold', () => { - const customStrategy = new NewUnifiedDiffStrategy(0.85) - expect(customStrategy['confidenceThreshold']).toBe(0.85) - }) + it("should enforce minimum confidence threshold", () => { + const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8 + expect(lowStrategy["confidenceThreshold"]).toBe(0.8) + }) + }) - it('should enforce minimum confidence threshold', () => { - const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8 - expect(lowStrategy['confidenceThreshold']).toBe(0.8) - }) - }) + describe("getToolDescription", () => { + it("should return tool description with correct cwd", () => { + const cwd = "/test/path" + const description = strategy.getToolDescription({ cwd }) - describe('getToolDescription', () => { - it('should return tool description with correct cwd', () => { - const cwd = '/test/path' - const description = strategy.getToolDescription({ cwd }) - - expect(description).toContain('apply_diff') - expect(description).toContain(cwd) - expect(description).toContain('Parameters:') - expect(description).toContain('Format Requirements:') - }) - }) + expect(description).toContain("apply_diff") + expect(description).toContain(cwd) + expect(description).toContain("Parameters:") + expect(description).toContain("Format Requirements:") + }) + }) - it('should apply simple diff correctly', async () => { - const original = `line1 + it("should apply simple diff correctly", async () => { + const original = `line1 line2 -line3`; +line3` - const diff = `--- a/file.txt + const diff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1 +new line line2 -line3 -+modified line3`; ++modified line3` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if(result.success) { - expect(result.content).toBe(`line1 + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`line1 new line line2 -modified line3`); - } - }); +modified line3`) + } + }) - it('should handle multiple hunks', async () => { - const original = `line1 + it("should handle multiple hunks", async () => { + const original = `line1 line2 line3 line4 -line5`; +line5` - const diff = `--- a/file.txt + const diff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1 @@ -80,23 +79,23 @@ line5`; line4 -line5 +modified line5 -+new line at end`; ++new line at end` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(`line1 + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`line1 new line line2 modified line3 line4 modified line5 -new line at end`); - } - }); +new line at end`) + } + }) - it('should handle complex large', async () => { - const original = `line1 + it("should handle complex large", async () => { + const original = `line1 line2 line3 line4 @@ -105,9 +104,9 @@ line6 line7 line8 line9 -line10`; +line10` - const diff = `--- a/file.txt + const diff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1 @@ -130,12 +129,12 @@ line10`; line9 -line10 +final line -+very last line`; ++very last line` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(`line1 + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`line1 header line another header line2 @@ -150,12 +149,12 @@ changed line8 bonus line line9 final line -very last line`); - } - }); +very last line`) + } + }) - it('should handle indentation changes', async () => { - const original = `first line + it("should handle indentation changes", async () => { + const original = `first line indented line double indented line back to single indent @@ -164,9 +163,9 @@ no indent double indent again triple indent back to single -last line`; +last line` - const diff = `--- original + const diff = `--- original +++ modified @@ ... @@ first line @@ -181,9 +180,9 @@ last line`; - triple indent + hi there mate back to single - last line`; + last line` - const expected = `first line + const expected = `first line indented line tab indented line new indented line @@ -194,23 +193,22 @@ no indent double indent again hi there mate back to single -last line`; +last line` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(expected); - } - }); + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) - it('should handle high level edits', async () => { - - const original = `def factorial(n): + it("should handle high level edits", async () => { + const original = `def factorial(n): if n == 0: return 1 else: return n * factorial(n-1)` - const diff = `@@ ... @@ + const diff = `@@ ... @@ -def factorial(n): - if n == 0: - return 1 @@ -222,21 +220,21 @@ last line`; + else: + return number * factorial(number-1)` -const expected = `def factorial(number): + const expected = `def factorial(number): if number == 0: return 1 else: return number * factorial(number-1)` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(expected); - } - }); + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) - it('it should handle very complex edits', async () => { - const original = `//Initialize the array that will hold the primes + it("it should handle very complex edits", async () => { + const original = `//Initialize the array that will hold the primes var primeArray = []; /*Write a function that checks for primeness and pushes those values to t*he array*/ @@ -269,7 +267,7 @@ for (var i = 2; primeArray.length < numPrimes; i++) { console.log(primeArray); ` - const diff = `--- test_diff.js + const diff = `--- test_diff.js +++ test_diff.js @@ ... @@ -//Initialize the array that will hold the primes @@ -297,7 +295,7 @@ console.log(primeArray); } console.log(primeArray);` - const expected = `var primeArray = []; + const expected = `var primeArray = []; function PrimeCheck(candidate){ isPrime = true; for(var i = 2; i < candidate && isPrime; i++){ @@ -320,58 +318,57 @@ for (var i = 2; primeArray.length < numPrimes; i++) { } console.log(primeArray); ` - - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(expected); - } - }); + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) - describe('error handling and edge cases', () => { - it('should reject completely invalid diff format', async () => { - const original = 'line1\nline2\nline3'; - const invalidDiff = 'this is not a diff at all'; - - const result = await strategy.applyDiff(original, invalidDiff); - expect(result.success).toBe(false); - }); + describe("error handling and edge cases", () => { + it("should reject completely invalid diff format", async () => { + const original = "line1\nline2\nline3" + const invalidDiff = "this is not a diff at all" - it('should reject diff with invalid hunk format', async () => { - const original = 'line1\nline2\nline3'; - const invalidHunkDiff = `--- a/file.txt + const result = await strategy.applyDiff(original, invalidDiff) + expect(result.success).toBe(false) + }) + + it("should reject diff with invalid hunk format", async () => { + const original = "line1\nline2\nline3" + const invalidHunkDiff = `--- a/file.txt +++ b/file.txt invalid hunk header line1 -line2 -+new line`; - - const result = await strategy.applyDiff(original, invalidHunkDiff); - expect(result.success).toBe(false); - }); ++new line` - it('should fail when diff tries to modify non-existent content', async () => { - const original = 'line1\nline2\nline3'; - const nonMatchingDiff = `--- a/file.txt + const result = await strategy.applyDiff(original, invalidHunkDiff) + expect(result.success).toBe(false) + }) + + it("should fail when diff tries to modify non-existent content", async () => { + const original = "line1\nline2\nline3" + const nonMatchingDiff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1 -nonexistent line +new line - line3`; - - const result = await strategy.applyDiff(original, nonMatchingDiff); - expect(result.success).toBe(false); - }); + line3` - it('should handle overlapping hunks', async () => { - const original = `line1 + const result = await strategy.applyDiff(original, nonMatchingDiff) + expect(result.success).toBe(false) + }) + + it("should handle overlapping hunks", async () => { + const original = `line1 line2 line3 line4 -line5`; - const overlappingDiff = `--- a/file.txt +line5` + const overlappingDiff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1 @@ -384,19 +381,19 @@ line5`; -line3 -line4 +modified3and4 - line5`; - - const result = await strategy.applyDiff(original, overlappingDiff); - expect(result.success).toBe(false); - }); + line5` - it('should handle empty lines modifications', async () => { - const original = `line1 + const result = await strategy.applyDiff(original, overlappingDiff) + expect(result.success).toBe(false) + }) + + it("should handle empty lines modifications", async () => { + const original = `line1 line3 -line5`; - const emptyLinesDiff = `--- a/file.txt +line5` + const emptyLinesDiff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1 @@ -404,73 +401,73 @@ line5`; -line3 +line3modified - line5`; - - const result = await strategy.applyDiff(original, emptyLinesDiff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(`line1 + line5` + + const result = await strategy.applyDiff(original, emptyLinesDiff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`line1 line3modified -line5`); - } - }); +line5`) + } + }) - it('should handle mixed line endings in diff', async () => { - const original = 'line1\r\nline2\nline3\r\n'; - const mixedEndingsDiff = `--- a/file.txt + it("should handle mixed line endings in diff", async () => { + const original = "line1\r\nline2\nline3\r\n" + const mixedEndingsDiff = `--- a/file.txt +++ b/file.txt @@ ... @@ line1\r -line2 +modified2\r - line3`; - - const result = await strategy.applyDiff(original, mixedEndingsDiff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe('line1\r\nmodified2\r\nline3\r\n'); - } - }); + line3` - it('should handle partial line modifications', async () => { - const original = 'const value = oldValue + 123;'; - const partialDiff = `--- a/file.txt + const result = await strategy.applyDiff(original, mixedEndingsDiff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n") + } + }) + + it("should handle partial line modifications", async () => { + const original = "const value = oldValue + 123;" + const partialDiff = `--- a/file.txt +++ b/file.txt @@ ... @@ -const value = oldValue + 123; -+const value = newValue + 123;`; - - const result = await strategy.applyDiff(original, partialDiff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe('const value = newValue + 123;'); - } - }); ++const value = newValue + 123;` - it('should handle slightly malformed but recoverable diff', async () => { - const original = 'line1\nline2\nline3'; - // Missing space after --- and +++ - const slightlyBadDiff = `---a/file.txt + const result = await strategy.applyDiff(original, partialDiff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("const value = newValue + 123;") + } + }) + + it("should handle slightly malformed but recoverable diff", async () => { + const original = "line1\nline2\nline3" + // Missing space after --- and +++ + const slightlyBadDiff = `---a/file.txt +++b/file.txt @@ ... @@ line1 -line2 +new line - line3`; - - const result = await strategy.applyDiff(original, slightlyBadDiff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe('line1\nnew line\nline3'); - } - }); - }); + line3` - describe('similar code sections', () => { - it('should correctly modify the right section when similar code exists', async () => { - const original = `function add(a, b) { + const result = await strategy.applyDiff(original, slightlyBadDiff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("line1\nnew line\nline3") + } + }) + }) + + describe("similar code sections", () => { + it("should correctly modify the right section when similar code exists", async () => { + const original = `function add(a, b) { return a + b; } @@ -480,20 +477,20 @@ function subtract(a, b) { function multiply(a, b) { return a + b; // Bug here -}`; +}` - const diff = `--- a/math.js + const diff = `--- a/math.js +++ b/math.js @@ ... @@ function multiply(a, b) { - return a + b; // Bug here + return a * b; - }`; + }` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(`function add(a, b) { + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function add(a, b) { return a + b; } @@ -503,12 +500,12 @@ function subtract(a, b) { function multiply(a, b) { return a * b; -}`); - } - }); +}`) + } + }) - it('should handle multiple similar sections with correct context', async () => { - const original = `if (condition) { + it("should handle multiple similar sections with correct context", async () => { + const original = `if (condition) { doSomething(); doSomething(); doSomething(); @@ -518,9 +515,9 @@ if (otherCondition) { doSomething(); doSomething(); doSomething(); -}`; +}` - const diff = `--- a/file.js + const diff = `--- a/file.js +++ b/file.js @@ ... @@ if (otherCondition) { @@ -528,12 +525,12 @@ if (otherCondition) { - doSomething(); + doSomethingElse(); doSomething(); - }`; + }` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(`if (condition) { + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`if (condition) { doSomething(); doSomething(); doSomething(); @@ -543,14 +540,14 @@ if (otherCondition) { doSomething(); doSomethingElse(); doSomething(); -}`); - } - }); - }); +}`) + } + }) + }) - describe('hunk splitting', () => { - it('should handle large diffs with multiple non-contiguous changes', async () => { - const original = `import { readFile } from 'fs'; + describe("hunk splitting", () => { + it("should handle large diffs with multiple non-contiguous changes", async () => { + const original = `import { readFile } from 'fs'; import { join } from 'path'; import { Logger } from './logger'; @@ -595,9 +592,9 @@ export { validateInput, writeOutput, parseConfig -};`; +};` - const diff = `--- a/file.ts + const diff = `--- a/file.ts +++ b/file.ts @@ ... @@ -import { readFile } from 'fs'; @@ -672,9 +669,9 @@ export { - parseConfig + parseConfig, + type Config - };`; + };` - const expected = `import { readFile, writeFile } from 'fs'; + const expected = `import { readFile, writeFile } from 'fs'; import { join } from 'path'; import { Logger } from './utils/logger'; import { Config } from './types'; @@ -727,13 +724,13 @@ export { writeOutput, parseConfig, type Config -};`; +};` - const result = await strategy.applyDiff(original, diff); - expect(result.success).toBe(true); - if (result.success) { - expect(result.content).toBe(expected); - } - }); - }); -}); \ No newline at end of file + const result = await strategy.applyDiff(original, diff) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) + }) +}) diff --git a/src/core/diff/strategies/__tests__/search-replace.test.ts b/src/core/diff/strategies/__tests__/search-replace.test.ts index 9436027..cd71eda 100644 --- a/src/core/diff/strategies/__tests__/search-replace.test.ts +++ b/src/core/diff/strategies/__tests__/search-replace.test.ts @@ -1,16 +1,16 @@ -import { SearchReplaceDiffStrategy } from '../search-replace' +import { SearchReplaceDiffStrategy } from "../search-replace" -describe('SearchReplaceDiffStrategy', () => { - describe('exact matching', () => { - let strategy: SearchReplaceDiffStrategy +describe("SearchReplaceDiffStrategy", () => { + describe("exact matching", () => { + let strategy: SearchReplaceDiffStrategy - beforeEach(() => { - strategy = new SearchReplaceDiffStrategy(1.0, 5) // Default 1.0 threshold for exact matching, 5 line buffer for tests - }) + beforeEach(() => { + strategy = new SearchReplaceDiffStrategy(1.0, 5) // Default 1.0 threshold for exact matching, 5 line buffer for tests + }) - it('should replace matching content', async () => { - const originalContent = 'function hello() {\n console.log("hello")\n}\n' - const diffContent = `test.ts + it("should replace matching content", async () => { + const originalContent = 'function hello() {\n console.log("hello")\n}\n' + const diffContent = `test.ts <<<<<<< SEARCH function hello() { console.log("hello") @@ -21,16 +21,16 @@ function hello() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('function hello() {\n console.log("hello world")\n}\n') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe('function hello() {\n console.log("hello world")\n}\n') + } + }) - it('should match content with different surrounding whitespace', async () => { - const originalContent = '\nfunction example() {\n return 42;\n}\n\n' - const diffContent = `test.ts + it("should match content with different surrounding whitespace", async () => { + const originalContent = "\nfunction example() {\n return 42;\n}\n\n" + const diffContent = `test.ts <<<<<<< SEARCH function example() { return 42; @@ -41,16 +41,16 @@ function example() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('\nfunction example() {\n return 43;\n}\n\n') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("\nfunction example() {\n return 43;\n}\n\n") + } + }) - it('should match content with different indentation in search block', async () => { - const originalContent = ' function test() {\n return true;\n }\n' - const diffContent = `test.ts + it("should match content with different indentation in search block", async () => { + const originalContent = " function test() {\n return true;\n }\n" + const diffContent = `test.ts <<<<<<< SEARCH function test() { return true; @@ -61,16 +61,16 @@ function test() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(' function test() {\n return false;\n }\n') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(" function test() {\n return false;\n }\n") + } + }) - it('should handle tab-based indentation', async () => { - const originalContent = "function test() {\n\treturn true;\n}\n" - const diffContent = `test.ts + it("should handle tab-based indentation", async () => { + const originalContent = "function test() {\n\treturn true;\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH function test() { \treturn true; @@ -81,16 +81,16 @@ function test() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe("function test() {\n\treturn false;\n}\n") - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("function test() {\n\treturn false;\n}\n") + } + }) - it('should preserve mixed tabs and spaces', async () => { - const originalContent = "\tclass Example {\n\t constructor() {\n\t\tthis.value = 0;\n\t }\n\t}" - const diffContent = `test.ts + it("should preserve mixed tabs and spaces", async () => { + const originalContent = "\tclass Example {\n\t constructor() {\n\t\tthis.value = 0;\n\t }\n\t}" + const diffContent = `test.ts <<<<<<< SEARCH \tclass Example { \t constructor() { @@ -105,16 +105,18 @@ function test() { \t} >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe("\tclass Example {\n\t constructor() {\n\t\tthis.value = 1;\n\t }\n\t}") - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe( + "\tclass Example {\n\t constructor() {\n\t\tthis.value = 1;\n\t }\n\t}", + ) + } + }) - it('should handle additional indentation with tabs', async () => { - const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}" - const diffContent = `test.ts + it("should handle additional indentation with tabs", async () => { + const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}" + const diffContent = `test.ts <<<<<<< SEARCH function test() { \treturn true; @@ -126,16 +128,16 @@ function test() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe("\tfunction test() {\n\t\t// Add comment\n\t\treturn false;\n\t}") - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("\tfunction test() {\n\t\t// Add comment\n\t\treturn false;\n\t}") + } + }) - it('should preserve exact indentation characters when adding lines', async () => { - const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}" - const diffContent = `test.ts + it("should preserve exact indentation characters when adding lines", async () => { + const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}" + const diffContent = `test.ts <<<<<<< SEARCH \tfunction test() { \t\treturn true; @@ -148,16 +150,18 @@ function test() { \t} >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe("\tfunction test() {\n\t\t// First comment\n\t\t// Second comment\n\t\treturn true;\n\t}") - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe( + "\tfunction test() {\n\t\t// First comment\n\t\t// Second comment\n\t\treturn true;\n\t}", + ) + } + }) - it('should handle Windows-style CRLF line endings', async () => { - const originalContent = "function test() {\r\n return true;\r\n}\r\n" - const diffContent = `test.ts + it("should handle Windows-style CRLF line endings", async () => { + const originalContent = "function test() {\r\n return true;\r\n}\r\n" + const diffContent = `test.ts <<<<<<< SEARCH function test() { return true; @@ -168,16 +172,16 @@ function test() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe("function test() {\r\n return false;\r\n}\r\n") - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("function test() {\r\n return false;\r\n}\r\n") + } + }) - it('should return false if search content does not match', async () => { - const originalContent = 'function hello() {\n console.log("hello")\n}\n' - const diffContent = `test.ts + it("should return false if search content does not match", async () => { + const originalContent = 'function hello() {\n console.log("hello")\n}\n' + const diffContent = `test.ts <<<<<<< SEARCH function hello() { console.log("wrong") @@ -188,21 +192,22 @@ function hello() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(false) - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(false) + }) - it('should return false if diff format is invalid', async () => { - const originalContent = 'function hello() {\n console.log("hello")\n}\n' - const diffContent = `test.ts\nInvalid diff format` + it("should return false if diff format is invalid", async () => { + const originalContent = 'function hello() {\n console.log("hello")\n}\n' + const diffContent = `test.ts\nInvalid diff format` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(false) - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(false) + }) - it('should handle multiple lines with proper indentation', async () => { - const originalContent = 'class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n return this.value\n }\n}\n' - const diffContent = `test.ts + it("should handle multiple lines with proper indentation", async () => { + const originalContent = + "class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n return this.value\n }\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH getValue() { return this.value @@ -215,16 +220,18 @@ function hello() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n // Add logging\n console.log("Getting value")\n return this.value\n }\n}\n') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe( + 'class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n // Add logging\n console.log("Getting value")\n return this.value\n }\n}\n', + ) + } + }) - it('should preserve whitespace exactly in the output', async () => { - const originalContent = " indented\n more indented\n back\n" - const diffContent = `test.ts + it("should preserve whitespace exactly in the output", async () => { + const originalContent = " indented\n more indented\n back\n" + const diffContent = `test.ts <<<<<<< SEARCH indented more indented @@ -235,16 +242,16 @@ function hello() { end >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(" modified\n still indented\n end\n") - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(" modified\n still indented\n end\n") + } + }) - it('should preserve indentation when adding new lines after existing content', async () => { - const originalContent = ' onScroll={() => updateHighlights()}' - const diffContent = `test.ts + it("should preserve indentation when adding new lines after existing content", async () => { + const originalContent = " onScroll={() => updateHighlights()}" + const diffContent = `test.ts <<<<<<< SEARCH onScroll={() => updateHighlights()} ======= @@ -255,15 +262,17 @@ function hello() { }} >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(' onScroll={() => updateHighlights()}\n onDragOver={(e) => {\n e.preventDefault()\n e.stopPropagation()\n }}') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe( + " onScroll={() => updateHighlights()}\n onDragOver={(e) => {\n e.preventDefault()\n e.stopPropagation()\n }}", + ) + } + }) - it('should handle varying indentation levels correctly', async () => { - const originalContent = ` + it("should handle varying indentation levels correctly", async () => { + const originalContent = ` class Example { constructor() { this.value = 0; @@ -271,9 +280,9 @@ class Example { this.init(); } } -}`.trim(); - - const diffContent = `test.ts +}`.trim() + + const diffContent = `test.ts <<<<<<< SEARCH class Example { constructor() { @@ -294,12 +303,13 @@ class Example { } } } ->>>>>>> REPLACE`.trim(); - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(` +>>>>>>> REPLACE`.trim() + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe( + ` class Example { constructor() { this.value = 1; @@ -309,20 +319,21 @@ class Example { this.validate(); } } -}`.trim()); - } - }) +}`.trim(), + ) + } + }) - it('should handle mixed indentation styles in the same file', async () => { - const originalContent = `class Example { + it("should handle mixed indentation styles in the same file", async () => { + const originalContent = `class Example { constructor() { this.value = 0; if (true) { this.init(); } } -}`.trim(); - const diffContent = `test.ts +}`.trim() + const diffContent = `test.ts <<<<<<< SEARCH constructor() { this.value = 0; @@ -338,12 +349,12 @@ class Example { this.validate(); } } ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`class Example { +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`class Example { constructor() { this.value = 1; if (true) { @@ -351,18 +362,18 @@ class Example { this.validate(); } } -}`); - } - }) - - it('should handle Python-style significant whitespace', async () => { - const originalContent = `def example(): +}`) + } + }) + + it("should handle Python-style significant whitespace", async () => { + const originalContent = `def example(): if condition: do_something() for item in items: process(item) - return True`.trim(); - const diffContent = `test.ts + return True`.trim() + const diffContent = `test.ts <<<<<<< SEARCH if condition: do_something() @@ -374,30 +385,30 @@ class Example { while items: item = items.pop() process(item) ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`def example(): +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`def example(): if condition: do_something() while items: item = items.pop() process(item) - return True`); - } - }); - - it('should preserve empty lines with indentation', async () => { - const originalContent = `function test() { + return True`) + } + }) + + it("should preserve empty lines with indentation", async () => { + const originalContent = `function test() { const x = 1; if (x) { return true; } -}`.trim(); - const diffContent = `test.ts +}`.trim() + const diffContent = `test.ts <<<<<<< SEARCH const x = 1; @@ -407,31 +418,31 @@ class Example { // Check x if (x) { ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function test() { +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function test() { const x = 1; // Check x if (x) { return true; } -}`); - } - }); - - it('should handle indentation when replacing entire blocks', async () => { - const originalContent = `class Test { +}`) + } + }) + + it("should handle indentation when replacing entire blocks", async () => { + const originalContent = `class Test { method() { if (true) { console.log("test"); } } -}`.trim(); - const diffContent = `test.ts +}`.trim() + const diffContent = `test.ts <<<<<<< SEARCH method() { if (true) { @@ -448,12 +459,12 @@ class Example { console.error(e); } } ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`class Test { +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`class Test { method() { try { if (true) { @@ -463,72 +474,72 @@ class Example { console.error(e); } } -}`); - } - }); +}`) + } + }) - it('should handle negative indentation relative to search content', async () => { - const originalContent = `class Example { + it("should handle negative indentation relative to search content", async () => { + const originalContent = `class Example { constructor() { if (true) { this.init(); this.setup(); } } -}`.trim(); - const diffContent = `test.ts +}`.trim() + const diffContent = `test.ts <<<<<<< SEARCH this.init(); this.setup(); ======= this.init(); this.setup(); ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`class Example { +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`class Example { constructor() { if (true) { this.init(); this.setup(); } } -}`); - } - }); - - it('should handle extreme negative indentation (no indent)', async () => { - const originalContent = `class Example { +}`) + } + }) + + it("should handle extreme negative indentation (no indent)", async () => { + const originalContent = `class Example { constructor() { if (true) { this.init(); } } -}`.trim(); - const diffContent = `test.ts +}`.trim() + const diffContent = `test.ts <<<<<<< SEARCH this.init(); ======= this.init(); ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`class Example { +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`class Example { constructor() { if (true) { this.init(); } } -}`); - } - }); - - it('should handle mixed indentation changes in replace block', async () => { - const originalContent = `class Example { +}`) + } + }) + + it("should handle mixed indentation changes in replace block", async () => { + const originalContent = `class Example { constructor() { if (true) { this.init(); @@ -536,8 +547,8 @@ this.init(); this.validate(); } } -}`.trim(); - const diffContent = `test.ts +}`.trim() + const diffContent = `test.ts <<<<<<< SEARCH this.init(); this.setup(); @@ -546,12 +557,12 @@ this.init(); this.init(); this.setup(); this.validate(); ->>>>>>> REPLACE`; - - const result = await strategy.applyDiff(originalContent, diffContent); - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`class Example { +>>>>>>> REPLACE` + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`class Example { constructor() { if (true) { this.init(); @@ -559,12 +570,12 @@ this.init(); this.validate(); } } -}`); - } - }); +}`) + } + }) - it('should find matches from middle out', async () => { - const originalContent = ` + it("should find matches from middle out", async () => { + const originalContent = ` function one() { return "target"; } @@ -584,21 +595,21 @@ function four() { function five() { return "target"; }`.trim() - - const diffContent = `test.ts + + const diffContent = `test.ts <<<<<<< SEARCH return "target"; ======= return "updated"; >>>>>>> REPLACE` - - // Search around the middle (function three) - // Even though all functions contain the target text, - // it should match the one closest to line 9 first - const result = await strategy.applyDiff(originalContent, diffContent, 9, 9) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + + // Search around the middle (function three) + // Even though all functions contain the target text, + // it should match the one closest to line 9 first + const result = await strategy.applyDiff(originalContent, diffContent, 9, 9) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return "target"; } @@ -617,21 +628,21 @@ function four() { function five() { return "target"; }`) - } - }) - }) + } + }) + }) - describe('line number stripping', () => { - describe('line number stripping', () => { - let strategy: SearchReplaceDiffStrategy - - beforeEach(() => { - strategy = new SearchReplaceDiffStrategy() - }) - - it('should strip line numbers from both search and replace sections', async () => { - const originalContent = 'function test() {\n return true;\n}\n' - const diffContent = `test.ts + describe("line number stripping", () => { + describe("line number stripping", () => { + let strategy: SearchReplaceDiffStrategy + + beforeEach(() => { + strategy = new SearchReplaceDiffStrategy() + }) + + it("should strip line numbers from both search and replace sections", async () => { + const originalContent = "function test() {\n return true;\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH 1 | function test() { 2 | return true; @@ -641,17 +652,17 @@ function five() { 2 | return false; 3 | } >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('function test() {\n return false;\n}\n') - } - }) - it('should strip line numbers with leading spaces', async () => { - const originalContent = 'function test() {\n return true;\n}\n' - const diffContent = `test.ts + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("function test() {\n return false;\n}\n") + } + }) + + it("should strip line numbers with leading spaces", async () => { + const originalContent = "function test() {\n return true;\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH 1 | function test() { 2 | return true; @@ -661,17 +672,17 @@ function five() { 2 | return false; 3 | } >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('function test() {\n return false;\n}\n') - } - }) - - it('should not strip when not all lines have numbers in either section', async () => { - const originalContent = 'function test() {\n return true;\n}\n' - const diffContent = `test.ts + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("function test() {\n return false;\n}\n") + } + }) + + it("should not strip when not all lines have numbers in either section", async () => { + const originalContent = "function test() {\n return true;\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH 1 | function test() { 2 | return true; @@ -681,14 +692,14 @@ function five() { return false; 3 | } >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(false) - }) - - it('should preserve content that naturally starts with pipe', async () => { - const originalContent = '|header|another|\n|---|---|\n|data|more|\n' - const diffContent = `test.ts + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(false) + }) + + it("should preserve content that naturally starts with pipe", async () => { + const originalContent = "|header|another|\n|---|---|\n|data|more|\n" + const diffContent = `test.ts <<<<<<< SEARCH 1 | |header|another| 2 | |---|---| @@ -698,17 +709,17 @@ function five() { 2 | |---|---| 3 | |data|updated| >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('|header|another|\n|---|---|\n|data|updated|\n') - } - }) - - it('should preserve indentation when stripping line numbers', async () => { - const originalContent = ' function test() {\n return true;\n }\n' - const diffContent = `test.ts + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("|header|another|\n|---|---|\n|data|updated|\n") + } + }) + + it("should preserve indentation when stripping line numbers", async () => { + const originalContent = " function test() {\n return true;\n }\n" + const diffContent = `test.ts <<<<<<< SEARCH 1 | function test() { 2 | return true; @@ -718,17 +729,17 @@ function five() { 2 | return false; 3 | } >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(' function test() {\n return false;\n }\n') - } - }) - - it('should handle different line numbers between sections', async () => { - const originalContent = 'function test() {\n return true;\n}\n' - const diffContent = `test.ts + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(" function test() {\n return false;\n }\n") + } + }) + + it("should handle different line numbers between sections", async () => { + const originalContent = "function test() {\n return true;\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH 10 | function test() { 11 | return true; @@ -738,17 +749,17 @@ function five() { 21 | return false; 22 | } >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('function test() {\n return false;\n}\n') - } - }) - it('should not strip content that starts with pipe but no line number', async () => { - const originalContent = '| Pipe\n|---|\n| Data\n' - const diffContent = `test.ts + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("function test() {\n return false;\n}\n") + } + }) + + it("should not strip content that starts with pipe but no line number", async () => { + const originalContent = "| Pipe\n|---|\n| Data\n" + const diffContent = `test.ts <<<<<<< SEARCH | Pipe |---| @@ -758,17 +769,17 @@ function five() { |---| | Updated >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('| Pipe\n|---|\n| Updated\n') - } - }) - - it('should handle mix of line-numbered and pipe-only content', async () => { - const originalContent = '| Pipe\n|---|\n| Data\n' - const diffContent = `test.ts + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("| Pipe\n|---|\n| Updated\n") + } + }) + + it("should handle mix of line-numbered and pipe-only content", async () => { + const originalContent = "| Pipe\n|---|\n| Data\n" + const diffContent = `test.ts <<<<<<< SEARCH | Pipe |---| @@ -778,48 +789,48 @@ function five() { 2 | |---| 3 | | NewData >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('1 | | Pipe\n2 | |---|\n3 | | NewData\n') - } - }) - }) - }); - describe('insertion/deletion', () => { - let strategy: SearchReplaceDiffStrategy - - beforeEach(() => { - strategy = new SearchReplaceDiffStrategy() - }) - - describe('deletion', () => { - it('should delete code when replace block is empty', async () => { - const originalContent = `function test() { + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("1 | | Pipe\n2 | |---|\n3 | | NewData\n") + } + }) + }) + }) + + describe("insertion/deletion", () => { + let strategy: SearchReplaceDiffStrategy + + beforeEach(() => { + strategy = new SearchReplaceDiffStrategy() + }) + + describe("deletion", () => { + it("should delete code when replace block is empty", async () => { + const originalContent = `function test() { console.log("hello"); // Comment to remove console.log("world"); }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH // Comment to remove ======= >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function test() { + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function test() { console.log("hello"); console.log("world"); }`) - } - }) - - it('should delete multiple lines when replace block is empty', async () => { - const originalContent = `class Example { + } + }) + + it("should delete multiple lines when replace block is empty", async () => { + const originalContent = `class Example { constructor() { // Initialize this.value = 0; @@ -828,7 +839,7 @@ function five() { // End init } }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH // Initialize this.value = 0; @@ -837,19 +848,19 @@ function five() { // End init ======= >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`class Example { + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`class Example { constructor() { } }`) - } - }) - - it('should preserve indentation when deleting nested code', async () => { - const originalContent = `function outer() { + } + }) + + it("should preserve indentation when deleting nested code", async () => { + const originalContent = `function outer() { if (true) { // Remove this console.log("test"); @@ -857,146 +868,147 @@ function five() { } return true; }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH // Remove this console.log("test"); // And this ======= >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function outer() { + + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function outer() { if (true) { } return true; }`) - } - }) - }) - - describe('insertion', () => { - it('should insert code at specified line when search block is empty', async () => { - const originalContent = `function test() { + } + }) + }) + + describe("insertion", () => { + it("should insert code at specified line when search block is empty", async () => { + const originalContent = `function test() { const x = 1; return x; }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH ======= console.log("Adding log"); >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent, 2, 2) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function test() { + + const result = await strategy.applyDiff(originalContent, diffContent, 2, 2) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function test() { console.log("Adding log"); const x = 1; return x; }`) - } - }) - - it('should preserve indentation when inserting at nested location', async () => { - const originalContent = `function test() { + } + }) + + it("should preserve indentation when inserting at nested location", async () => { + const originalContent = `function test() { if (true) { const x = 1; } }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH ======= console.log("Before"); console.log("After"); >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent, 3, 3) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function test() { + + const result = await strategy.applyDiff(originalContent, diffContent, 3, 3) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function test() { if (true) { console.log("Before"); console.log("After"); const x = 1; } }`) - } - }) - - it('should handle insertion at start of file', async () => { - const originalContent = `function test() { + } + }) + + it("should handle insertion at start of file", async () => { + const originalContent = `function test() { return true; }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH ======= // Copyright 2024 // License: MIT >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent, 1, 1) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`// Copyright 2024 + + const result = await strategy.applyDiff(originalContent, diffContent, 1, 1) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`// Copyright 2024 // License: MIT function test() { return true; }`) - } - }) - - it('should handle insertion at end of file', async () => { - const originalContent = `function test() { + } + }) + + it("should handle insertion at end of file", async () => { + const originalContent = `function test() { return true; }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH ======= // End of file >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent, 4, 4) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function test() { + + const result = await strategy.applyDiff(originalContent, diffContent, 4, 4) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function test() { return true; } // End of file`) - } - }) - - it('should error if no start_line is provided for insertion', async () => { - const originalContent = `function test() { + } + }) + + it("should error if no start_line is provided for insertion", async () => { + const originalContent = `function test() { return true; }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH ======= console.log("test"); >>>>>>> REPLACE` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(false) - }) - }) - }) - describe('fuzzy matching', () => { - let strategy: SearchReplaceDiffStrategy - beforeEach(() => { - strategy = new SearchReplaceDiffStrategy(0.9, 5) // 90% similarity threshold, 5 line buffer for tests - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(false) + }) + }) + }) - it('should match content with small differences (>90% similar)', async () => { - const originalContent = 'function getData() {\n const results = fetchData();\n return results.filter(Boolean);\n}\n' - const diffContent = `test.ts + describe("fuzzy matching", () => { + let strategy: SearchReplaceDiffStrategy + beforeEach(() => { + strategy = new SearchReplaceDiffStrategy(0.9, 5) // 90% similarity threshold, 5 line buffer for tests + }) + + it("should match content with small differences (>90% similar)", async () => { + const originalContent = + "function getData() {\n const results = fetchData();\n return results.filter(Boolean);\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH function getData() { const result = fetchData(); @@ -1009,18 +1021,20 @@ function getData() { } >>>>>>> REPLACE` - strategy = new SearchReplaceDiffStrategy(0.9, 5) // Use 5 line buffer for tests + strategy = new SearchReplaceDiffStrategy(0.9, 5) // Use 5 line buffer for tests - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('function getData() {\n const data = fetchData();\n return data.filter(Boolean);\n}\n') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe( + "function getData() {\n const data = fetchData();\n return data.filter(Boolean);\n}\n", + ) + } + }) - it('should not match when content is too different (<90% similar)', async () => { - const originalContent = 'function processUsers(data) {\n return data.map(user => user.name);\n}\n' - const diffContent = `test.ts + it("should not match when content is too different (<90% similar)", async () => { + const originalContent = "function processUsers(data) {\n return data.map(user => user.name);\n}\n" + const diffContent = `test.ts <<<<<<< SEARCH function handleItems(items) { return items.map(item => item.username); @@ -1031,13 +1045,13 @@ function processData(data) { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(false) - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(false) + }) - it('should match content with extra whitespace', async () => { - const originalContent = 'function sum(a, b) {\n return a + b;\n}' - const diffContent = `test.ts + it("should match content with extra whitespace", async () => { + const originalContent = "function sum(a, b) {\n return a + b;\n}" + const diffContent = `test.ts <<<<<<< SEARCH function sum(a, b) { return a + b; @@ -1048,16 +1062,16 @@ function sum(a, b) { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('function sum(a, b) {\n return a + b + 1;\n}') - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe("function sum(a, b) {\n return a + b + 1;\n}") + } + }) - it('should not exact match empty lines', async () => { - const originalContent = 'function sum(a, b) {\n\n return a + b;\n}' - const diffContent = `test.ts + it("should not exact match empty lines", async () => { + const originalContent = "function sum(a, b) {\n\n return a + b;\n}" + const diffContent = `test.ts <<<<<<< SEARCH function sum(a, b) { ======= @@ -1065,23 +1079,23 @@ import { a } from "a"; function sum(a, b) { >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe('import { a } from "a";\nfunction sum(a, b) {\n\n return a + b;\n}') - } - }) - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe('import { a } from "a";\nfunction sum(a, b) {\n\n return a + b;\n}') + } + }) + }) - describe('line-constrained search', () => { - let strategy: SearchReplaceDiffStrategy + describe("line-constrained search", () => { + let strategy: SearchReplaceDiffStrategy - beforeEach(() => { - strategy = new SearchReplaceDiffStrategy(0.9, 5) - }) + beforeEach(() => { + strategy = new SearchReplaceDiffStrategy(0.9, 5) + }) - it('should find and replace within specified line range', async () => { - const originalContent = ` + it("should find and replace within specified line range", async () => { + const originalContent = ` function one() { return 1; } @@ -1094,7 +1108,7 @@ function three() { return 3; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function two() { return 2; @@ -1105,10 +1119,10 @@ function two() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return 1; } @@ -1119,11 +1133,11 @@ function two() { function three() { return 3; }`) - } - }) + } + }) - it('should find and replace within buffer zone (5 lines before/after)', async () => { - const originalContent = ` + it("should find and replace within buffer zone (5 lines before/after)", async () => { + const originalContent = ` function one() { return 1; } @@ -1136,7 +1150,7 @@ function three() { return 3; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function three() { return 3; @@ -1147,12 +1161,12 @@ function three() { } >>>>>>> REPLACE` - // Even though we specify lines 5-7, it should still find the match at lines 9-11 - // because it's within the 5-line buffer zone - const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + // Even though we specify lines 5-7, it should still find the match at lines 9-11 + // because it's within the 5-line buffer zone + const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return 1; } @@ -1163,11 +1177,11 @@ function two() { function three() { return "three"; }`) - } - }) + } + }) - it('should not find matches outside search range and buffer zone', async () => { - const originalContent = ` + it("should not find matches outside search range and buffer zone", async () => { + const originalContent = ` function one() { return 1; } @@ -1188,7 +1202,7 @@ function five() { return 5; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function five() { return 5; @@ -1199,14 +1213,14 @@ function five() { } >>>>>>> REPLACE` - // Searching around function two() (lines 5-7) - // function five() is more than 5 lines away, so it shouldn't match - const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) - expect(result.success).toBe(false) - }) + // Searching around function two() (lines 5-7) + // function five() is more than 5 lines away, so it shouldn't match + const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) + expect(result.success).toBe(false) + }) - it('should handle search range at start of file', async () => { - const originalContent = ` + it("should handle search range at start of file", async () => { + const originalContent = ` function one() { return 1; } @@ -1215,7 +1229,7 @@ function two() { return 2; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function one() { return 1; @@ -1226,21 +1240,21 @@ function one() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent, 1, 3) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + const result = await strategy.applyDiff(originalContent, diffContent, 1, 3) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return "one"; } function two() { return 2; }`) - } - }) + } + }) - it('should handle search range at end of file', async () => { - const originalContent = ` + it("should handle search range at end of file", async () => { + const originalContent = ` function one() { return 1; } @@ -1249,7 +1263,7 @@ function two() { return 2; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function two() { return 2; @@ -1260,21 +1274,21 @@ function two() { } >>>>>>> REPLACE` - const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + const result = await strategy.applyDiff(originalContent, diffContent, 5, 7) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return 1; } function two() { return "two"; }`) - } - }) + } + }) - it('should match specific instance of duplicate code using line numbers', async () => { - const originalContent = ` + it("should match specific instance of duplicate code using line numbers", async () => { + const originalContent = ` function processData(data) { return data.map(x => x * 2); } @@ -1292,7 +1306,7 @@ function moreStuff() { console.log("world"); } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function processData(data) { return data.map(x => x * 2); @@ -1305,11 +1319,11 @@ function processData(data) { } >>>>>>> REPLACE` - // Target the second instance of processData - const result = await strategy.applyDiff(originalContent, diffContent, 10, 12) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function processData(data) { + // Target the second instance of processData + const result = await strategy.applyDiff(originalContent, diffContent, 10, 12) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function processData(data) { return data.map(x => x * 2); } @@ -1327,11 +1341,11 @@ function processData(data) { function moreStuff() { console.log("world"); }`) - } - }) + } + }) - it('should search from start line to end of file when only start_line is provided', async () => { - const originalContent = ` + it("should search from start line to end of file when only start_line is provided", async () => { + const originalContent = ` function one() { return 1; } @@ -1344,7 +1358,7 @@ function three() { return 3; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function three() { return 3; @@ -1355,11 +1369,11 @@ function three() { } >>>>>>> REPLACE` - // Only provide start_line, should search from there to end of file - const result = await strategy.applyDiff(originalContent, diffContent, 8) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + // Only provide start_line, should search from there to end of file + const result = await strategy.applyDiff(originalContent, diffContent, 8) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return 1; } @@ -1370,11 +1384,11 @@ function two() { function three() { return "three"; }`) - } - }) + } + }) - it('should search from start of file to end line when only end_line is provided', async () => { - const originalContent = ` + it("should search from start of file to end line when only end_line is provided", async () => { + const originalContent = ` function one() { return 1; } @@ -1387,7 +1401,7 @@ function three() { return 3; } `.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function one() { return 1; @@ -1398,11 +1412,11 @@ function one() { } >>>>>>> REPLACE` - // Only provide end_line, should search from start of file to there - const result = await strategy.applyDiff(originalContent, diffContent, undefined, 4) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + // Only provide end_line, should search from start of file to there + const result = await strategy.applyDiff(originalContent, diffContent, undefined, 4) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return "one"; } @@ -1413,11 +1427,11 @@ function two() { function three() { return 3; }`) - } - }) + } + }) - it('should prioritize exact line match over expanded search', async () => { - const originalContent = ` + it("should prioritize exact line match over expanded search", async () => { + const originalContent = ` function one() { return 1; } @@ -1433,7 +1447,7 @@ function process() { function two() { return 2; }` - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function process() { return "old"; @@ -1444,12 +1458,12 @@ function process() { } >>>>>>> REPLACE` - // Should match the second instance exactly at lines 10-12 - // even though the first instance at 6-8 is within the expanded search range - const result = await strategy.applyDiff(originalContent, diffContent, 10, 12) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(` + // Should match the second instance exactly at lines 10-12 + // even though the first instance at 6-8 is within the expanded search range + const result = await strategy.applyDiff(originalContent, diffContent, 10, 12) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(` function one() { return 1; } @@ -1465,11 +1479,11 @@ function process() { function two() { return 2; }`) - } - }) + } + }) - it('should fall back to expanded search only if exact match fails', async () => { - const originalContent = ` + it("should fall back to expanded search only if exact match fails", async () => { + const originalContent = ` function one() { return 1; } @@ -1481,7 +1495,7 @@ function process() { function two() { return 2; }`.trim() - const diffContent = `test.ts + const diffContent = `test.ts <<<<<<< SEARCH function process() { return "target"; @@ -1492,12 +1506,12 @@ function process() { } >>>>>>> REPLACE` - // Specify wrong line numbers (3-5), but content exists at 6-8 - // Should still find and replace it since it's within the expanded range - const result = await strategy.applyDiff(originalContent, diffContent, 3, 5) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(`function one() { + // Specify wrong line numbers (3-5), but content exists at 6-8 + // Should still find and replace it since it's within the expanded range + const result = await strategy.applyDiff(originalContent, diffContent, 3, 5) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(`function one() { return 1; } @@ -1508,36 +1522,36 @@ function process() { function two() { return 2; }`) - } - }) - }) + } + }) + }) - describe('getToolDescription', () => { - let strategy: SearchReplaceDiffStrategy + describe("getToolDescription", () => { + let strategy: SearchReplaceDiffStrategy - beforeEach(() => { - strategy = new SearchReplaceDiffStrategy() - }) + beforeEach(() => { + strategy = new SearchReplaceDiffStrategy() + }) - it('should include the current working directory', async () => { - const cwd = '/test/dir' - const description = await strategy.getToolDescription({ cwd }) - expect(description).toContain(`relative to the current working directory ${cwd}`) - }) + it("should include the current working directory", async () => { + const cwd = "/test/dir" + const description = await strategy.getToolDescription({ cwd }) + expect(description).toContain(`relative to the current working directory ${cwd}`) + }) - it('should include required format elements', async () => { - const description = await strategy.getToolDescription({ cwd: '/test' }) - expect(description).toContain('<<<<<<< SEARCH') - expect(description).toContain('=======') - expect(description).toContain('>>>>>>> REPLACE') - expect(description).toContain('') - expect(description).toContain('') - }) + it("should include required format elements", async () => { + const description = await strategy.getToolDescription({ cwd: "/test" }) + expect(description).toContain("<<<<<<< SEARCH") + expect(description).toContain("=======") + expect(description).toContain(">>>>>>> REPLACE") + expect(description).toContain("") + expect(description).toContain("") + }) - it('should document start_line and end_line parameters', async () => { - const description = await strategy.getToolDescription({ cwd: '/test' }) - expect(description).toContain('start_line: (required) The line number where the search block starts.') - expect(description).toContain('end_line: (required) The line number where the search block ends.') - }) - }) + it("should document start_line and end_line parameters", async () => { + const description = await strategy.getToolDescription({ cwd: "/test" }) + expect(description).toContain("start_line: (required) The line number where the search block starts.") + expect(description).toContain("end_line: (required) The line number where the search block ends.") + }) + }) }) diff --git a/src/core/diff/strategies/__tests__/unified.test.ts b/src/core/diff/strategies/__tests__/unified.test.ts index 949db04..1d9847b 100644 --- a/src/core/diff/strategies/__tests__/unified.test.ts +++ b/src/core/diff/strategies/__tests__/unified.test.ts @@ -1,27 +1,27 @@ -import { UnifiedDiffStrategy } from '../unified' +import { UnifiedDiffStrategy } from "../unified" -describe('UnifiedDiffStrategy', () => { - let strategy: UnifiedDiffStrategy +describe("UnifiedDiffStrategy", () => { + let strategy: UnifiedDiffStrategy - beforeEach(() => { - strategy = new UnifiedDiffStrategy() - }) + beforeEach(() => { + strategy = new UnifiedDiffStrategy() + }) - describe('getToolDescription', () => { - it('should return tool description with correct cwd', () => { - const cwd = '/test/path' - const description = strategy.getToolDescription({ cwd }) - - expect(description).toContain('apply_diff') - expect(description).toContain(cwd) - expect(description).toContain('Parameters:') - expect(description).toContain('Format Requirements:') - }) - }) + describe("getToolDescription", () => { + it("should return tool description with correct cwd", () => { + const cwd = "/test/path" + const description = strategy.getToolDescription({ cwd }) - describe('applyDiff', () => { - it('should successfully apply a function modification diff', async () => { - const originalContent = `import { Logger } from '../logger'; + expect(description).toContain("apply_diff") + expect(description).toContain(cwd) + expect(description).toContain("Parameters:") + expect(description).toContain("Format Requirements:") + }) + }) + + describe("applyDiff", () => { + it("should successfully apply a function modification diff", async () => { + const originalContent = `import { Logger } from '../logger'; function calculateTotal(items: number[]): number { return items.reduce((sum, item) => { @@ -31,7 +31,7 @@ function calculateTotal(items: number[]): number { export { calculateTotal };` - const diffContent = `--- src/utils/helper.ts + const diffContent = `--- src/utils/helper.ts +++ src/utils/helper.ts @@ -1,9 +1,10 @@ import { Logger } from '../logger'; @@ -47,7 +47,7 @@ export { calculateTotal };` export { calculateTotal };` - const expected = `import { Logger } from '../logger'; + const expected = `import { Logger } from '../logger'; function calculateTotal(items: number[]): number { const total = items.reduce((sum, item) => { @@ -58,21 +58,21 @@ function calculateTotal(items: number[]): number { export { calculateTotal };` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(expected) - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) - it('should successfully apply a diff adding a new method', async () => { - const originalContent = `class Calculator { + it("should successfully apply a diff adding a new method", async () => { + const originalContent = `class Calculator { add(a: number, b: number): number { return a + b; } }` - const diffContent = `--- src/Calculator.ts + const diffContent = `--- src/Calculator.ts +++ src/Calculator.ts @@ -1,5 +1,9 @@ class Calculator { @@ -85,7 +85,7 @@ export { calculateTotal };` + } }` - const expected = `class Calculator { + const expected = `class Calculator { add(a: number, b: number): number { return a + b; } @@ -95,15 +95,15 @@ export { calculateTotal };` } }` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(expected) - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) - it('should successfully apply a diff modifying imports', async () => { - const originalContent = `import { useState } from 'react'; + it("should successfully apply a diff modifying imports", async () => { + const originalContent = `import { useState } from 'react'; import { Button } from './components'; function App() { @@ -111,7 +111,7 @@ function App() { return ; }` - const diffContent = `--- src/App.tsx + const diffContent = `--- src/App.tsx +++ src/App.tsx @@ -1,7 +1,8 @@ -import { useState } from 'react'; @@ -124,7 +124,7 @@ function App() { return ; }` - const expected = `import { useState, useEffect } from 'react'; + const expected = `import { useState, useEffect } from 'react'; import { Button } from './components'; function App() { @@ -132,16 +132,16 @@ function App() { useEffect(() => { document.title = \`Count: \${count}\` }, [count]); return ; }` - - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(expected) - } - }) - it('should successfully apply a diff with multiple hunks', async () => { - const originalContent = `import { readFile, writeFile } from 'fs'; + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) + + it("should successfully apply a diff with multiple hunks", async () => { + const originalContent = `import { readFile, writeFile } from 'fs'; function processFile(path: string) { readFile(path, 'utf8', (err, data) => { @@ -155,7 +155,7 @@ function processFile(path: string) { export { processFile };` - const diffContent = `--- src/file-processor.ts + const diffContent = `--- src/file-processor.ts +++ src/file-processor.ts @@ -1,12 +1,14 @@ -import { readFile, writeFile } from 'fs'; @@ -182,7 +182,7 @@ export { processFile };` export { processFile };` - const expected = `import { promises as fs } from 'fs'; + const expected = `import { promises as fs } from 'fs'; import { join } from 'path'; async function processFile(path: string) { @@ -198,32 +198,31 @@ async function processFile(path: string) { export { processFile };` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(expected) - } - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) - it('should handle empty original content', async () => { - const originalContent = '' - const diffContent = `--- empty.ts + it("should handle empty original content", async () => { + const originalContent = "" + const diffContent = `--- empty.ts +++ empty.ts @@ -0,0 +1,3 @@ +export function greet(name: string): string { + return \`Hello, \${name}!\`; +}` - const expected = `export function greet(name: string): string { + const expected = `export function greet(name: string): string { return \`Hello, \${name}!\`; }\n` - const result = await strategy.applyDiff(originalContent, diffContent) - expect(result.success).toBe(true) - if (result.success) { - expect(result.content).toBe(expected) - } - }) - }) + const result = await strategy.applyDiff(originalContent, diffContent) + expect(result.success).toBe(true) + if (result.success) { + expect(result.content).toBe(expected) + } + }) + }) }) - diff --git a/src/core/diff/strategies/new-unified/__tests__/edit-strategies.test.ts b/src/core/diff/strategies/new-unified/__tests__/edit-strategies.test.ts index 2ed1cc9..2bc3554 100644 --- a/src/core/diff/strategies/new-unified/__tests__/edit-strategies.test.ts +++ b/src/core/diff/strategies/new-unified/__tests__/edit-strategies.test.ts @@ -265,8 +265,8 @@ describe("applyGitFallback", () => { { type: "context", content: "line1", indent: "" }, { type: "remove", content: "line2", indent: "" }, { type: "add", content: "new line2", indent: "" }, - { type: "context", content: "line3", indent: "" } - ] + { type: "context", content: "line3", indent: "" }, + ], } as Hunk const content = ["line1", "line2", "line3"] @@ -281,8 +281,8 @@ describe("applyGitFallback", () => { const hunk = { changes: [ { type: "context", content: "nonexistent", indent: "" }, - { type: "add", content: "new line", indent: "" } - ] + { type: "add", content: "new line", indent: "" }, + ], } as Hunk const content = ["line1", "line2", "line3"] diff --git a/src/core/diff/strategies/new-unified/__tests__/search-strategies.test.ts b/src/core/diff/strategies/new-unified/__tests__/search-strategies.test.ts index 6c4aba5..5bee537 100644 --- a/src/core/diff/strategies/new-unified/__tests__/search-strategies.test.ts +++ b/src/core/diff/strategies/new-unified/__tests__/search-strategies.test.ts @@ -3,7 +3,7 @@ import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMa type SearchStrategy = ( searchStr: string, content: string[], - startIndex?: number + startIndex?: number, ) => { index: number confidence: number @@ -11,141 +11,141 @@ type SearchStrategy = ( } const testCases = [ - { - name: "should return no match if the search string is not found", - searchStr: "not found", - content: ["line1", "line2", "line3"], - expected: { index: -1, confidence: 0 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match if the search string is found", - searchStr: "line2", - content: ["line1", "line2", "line3"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with correct index when startIndex is provided", - searchStr: "line3", - content: ["line1", "line2", "line3", "line4", "line3"], - startIndex: 3, - expected: { index: 4, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match even if there are more lines in content", - searchStr: "line2", - content: ["line1", "line2", "line3", "line4", "line5"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match even if the search string is at the beginning of the content", - searchStr: "line1", - content: ["line1", "line2", "line3"], - expected: { index: 0, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match even if the search string is at the end of the content", - searchStr: "line3", - content: ["line1", "line2", "line3"], - expected: { index: 2, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match for a multi-line search string", - searchStr: "line2\nline3", - content: ["line1", "line2", "line3", "line4"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return no match if a multi-line search string is not found", - searchStr: "line2\nline4", - content: ["line1", "line2", "line3", "line4"], - expected: { index: -1, confidence: 0 }, - strategies: ["exact", "similarity"], - }, - { - name: "should return a match with indentation", - searchStr: " line2", - content: ["line1", " line2", "line3"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with more complex indentation", - searchStr: " line3", - content: [" line1", " line2", " line3", " line4"], - expected: { index: 2, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with mixed indentation", - searchStr: "\tline2", - content: [" line1", "\tline2", " line3"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with mixed indentation and multi-line", - searchStr: " line2\n\tline3", - content: ["line1", " line2", "\tline3", " line4"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return no match if mixed indentation and multi-line is not found", - searchStr: " line2\n line4", - content: ["line1", " line2", "\tline3", " line4"], - expected: { index: -1, confidence: 0 }, - strategies: ["exact", "similarity"], - }, - { - name: "should return a match with leading and trailing spaces", - searchStr: " line2 ", - content: ["line1", " line2 ", "line3"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with leading and trailing tabs", - searchStr: "\tline2\t", - content: ["line1", "\tline2\t", "line3"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with mixed leading and trailing spaces and tabs", - searchStr: " \tline2\t ", - content: ["line1", " \tline2\t ", "line3"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return a match with mixed leading and trailing spaces and tabs and multi-line", - searchStr: " \tline2\t \n line3 ", - content: ["line1", " \tline2\t ", " line3 ", "line4"], - expected: { index: 1, confidence: 1 }, - strategies: ["exact", "similarity", "levenshtein"], - }, - { - name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found", - searchStr: " \tline2\t \n line4 ", - content: ["line1", " \tline2\t ", " line3 ", "line4"], - expected: { index: -1, confidence: 0 }, - strategies: ["exact", "similarity"], - }, + { + name: "should return no match if the search string is not found", + searchStr: "not found", + content: ["line1", "line2", "line3"], + expected: { index: -1, confidence: 0 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match if the search string is found", + searchStr: "line2", + content: ["line1", "line2", "line3"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with correct index when startIndex is provided", + searchStr: "line3", + content: ["line1", "line2", "line3", "line4", "line3"], + startIndex: 3, + expected: { index: 4, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match even if there are more lines in content", + searchStr: "line2", + content: ["line1", "line2", "line3", "line4", "line5"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match even if the search string is at the beginning of the content", + searchStr: "line1", + content: ["line1", "line2", "line3"], + expected: { index: 0, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match even if the search string is at the end of the content", + searchStr: "line3", + content: ["line1", "line2", "line3"], + expected: { index: 2, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match for a multi-line search string", + searchStr: "line2\nline3", + content: ["line1", "line2", "line3", "line4"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return no match if a multi-line search string is not found", + searchStr: "line2\nline4", + content: ["line1", "line2", "line3", "line4"], + expected: { index: -1, confidence: 0 }, + strategies: ["exact", "similarity"], + }, + { + name: "should return a match with indentation", + searchStr: " line2", + content: ["line1", " line2", "line3"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with more complex indentation", + searchStr: " line3", + content: [" line1", " line2", " line3", " line4"], + expected: { index: 2, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with mixed indentation", + searchStr: "\tline2", + content: [" line1", "\tline2", " line3"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with mixed indentation and multi-line", + searchStr: " line2\n\tline3", + content: ["line1", " line2", "\tline3", " line4"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return no match if mixed indentation and multi-line is not found", + searchStr: " line2\n line4", + content: ["line1", " line2", "\tline3", " line4"], + expected: { index: -1, confidence: 0 }, + strategies: ["exact", "similarity"], + }, + { + name: "should return a match with leading and trailing spaces", + searchStr: " line2 ", + content: ["line1", " line2 ", "line3"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with leading and trailing tabs", + searchStr: "\tline2\t", + content: ["line1", "\tline2\t", "line3"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with mixed leading and trailing spaces and tabs", + searchStr: " \tline2\t ", + content: ["line1", " \tline2\t ", "line3"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return a match with mixed leading and trailing spaces and tabs and multi-line", + searchStr: " \tline2\t \n line3 ", + content: ["line1", " \tline2\t ", " line3 ", "line4"], + expected: { index: 1, confidence: 1 }, + strategies: ["exact", "similarity", "levenshtein"], + }, + { + name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found", + searchStr: " \tline2\t \n line4 ", + content: ["line1", " \tline2\t ", " line3 ", "line4"], + expected: { index: -1, confidence: 0 }, + strategies: ["exact", "similarity"], + }, ] describe("findExactMatch", () => { - testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => { + testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => { if (!strategies?.includes("exact")) { return } - it(name, () => { + it(name, () => { const result = findExactMatch(searchStr, content, startIndex) expect(result.index).toBe(expected.index) expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence) @@ -155,16 +155,16 @@ describe("findExactMatch", () => { }) describe("findAnchorMatch", () => { - const anchorTestCases = [ - { - name: "should return no match if no anchors are found", - searchStr: " \n \n ", - content: ["line1", "line2", "line3"], - expected: { index: -1, confidence: 0 }, - }, - { - name: "should return no match if anchor positions cannot be validated", - searchStr: "unique line\ncontext line 1\ncontext line 2", + const anchorTestCases = [ + { + name: "should return no match if no anchors are found", + searchStr: " \n \n ", + content: ["line1", "line2", "line3"], + expected: { index: -1, confidence: 0 }, + }, + { + name: "should return no match if anchor positions cannot be validated", + searchStr: "unique line\ncontext line 1\ncontext line 2", content: [ "different line 1", "different line 2", @@ -173,24 +173,24 @@ describe("findAnchorMatch", () => { "context line 1", "context line 2", ], - expected: { index: -1, confidence: 0 }, - }, - { - name: "should return a match if anchor positions can be validated", - searchStr: "unique line\ncontext line 1\ncontext line 2", - content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"], - expected: { index: 2, confidence: 1 }, - }, - { - name: "should return a match with correct index when startIndex is provided", - searchStr: "unique line\ncontext line 1\ncontext line 2", - content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"], - startIndex: 3, - expected: { index: 3, confidence: 1 }, - }, - { - name: "should return a match even if there are more lines in content", - searchStr: "unique line\ncontext line 1\ncontext line 2", + expected: { index: -1, confidence: 0 }, + }, + { + name: "should return a match if anchor positions can be validated", + searchStr: "unique line\ncontext line 1\ncontext line 2", + content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"], + expected: { index: 2, confidence: 1 }, + }, + { + name: "should return a match with correct index when startIndex is provided", + searchStr: "unique line\ncontext line 1\ncontext line 2", + content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"], + startIndex: 3, + expected: { index: 3, confidence: 1 }, + }, + { + name: "should return a match even if there are more lines in content", + searchStr: "unique line\ncontext line 1\ncontext line 2", content: [ "line1", "line2", @@ -201,30 +201,30 @@ describe("findAnchorMatch", () => { "extra line 1", "extra line 2", ], - expected: { index: 2, confidence: 1 }, - }, - { - name: "should return a match even if the anchor is at the beginning of the content", - searchStr: "unique line\ncontext line 1\ncontext line 2", - content: ["unique line", "context line 1", "context line 2", "line 6"], - expected: { index: 0, confidence: 1 }, - }, - { - name: "should return a match even if the anchor is at the end of the content", - searchStr: "unique line\ncontext line 1\ncontext line 2", - content: ["line1", "line2", "unique line", "context line 1", "context line 2"], - expected: { index: 2, confidence: 1 }, - }, - { - name: "should return no match if no valid anchor is found", - searchStr: "non-unique line\ncontext line 1\ncontext line 2", - content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"], - expected: { index: -1, confidence: 0 }, - }, + expected: { index: 2, confidence: 1 }, + }, + { + name: "should return a match even if the anchor is at the beginning of the content", + searchStr: "unique line\ncontext line 1\ncontext line 2", + content: ["unique line", "context line 1", "context line 2", "line 6"], + expected: { index: 0, confidence: 1 }, + }, + { + name: "should return a match even if the anchor is at the end of the content", + searchStr: "unique line\ncontext line 1\ncontext line 2", + content: ["line1", "line2", "unique line", "context line 1", "context line 2"], + expected: { index: 2, confidence: 1 }, + }, + { + name: "should return no match if no valid anchor is found", + searchStr: "non-unique line\ncontext line 1\ncontext line 2", + content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"], + expected: { index: -1, confidence: 0 }, + }, ] - anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => { - it(name, () => { + anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => { + it(name, () => { const result = findAnchorMatch(searchStr, content, startIndex) expect(result.index).toBe(expected.index) expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence) @@ -234,11 +234,11 @@ describe("findAnchorMatch", () => { }) describe("findSimilarityMatch", () => { - testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => { + testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => { if (!strategies?.includes("similarity")) { return } - it(name, () => { + it(name, () => { const result = findSimilarityMatch(searchStr, content, startIndex) expect(result.index).toBe(expected.index) expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence) @@ -248,11 +248,11 @@ describe("findSimilarityMatch", () => { }) describe("findLevenshteinMatch", () => { - testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => { + testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => { if (!strategies?.includes("levenshtein")) { return } - it(name, () => { + it(name, () => { const result = findLevenshteinMatch(searchStr, content, startIndex) expect(result.index).toBe(expected.index) expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence) diff --git a/src/core/diff/strategies/new-unified/edit-strategies.ts b/src/core/diff/strategies/new-unified/edit-strategies.ts index 0828c83..81922b1 100644 --- a/src/core/diff/strategies/new-unified/edit-strategies.ts +++ b/src/core/diff/strategies/new-unified/edit-strategies.ts @@ -18,7 +18,7 @@ function inferIndentation(line: string, contextLines: string[], previousIndent: const contextLine = contextLines[0] if (contextLine) { const contextMatch = contextLine.match(/^(\s+)/) - if (contextMatch) { + if (contextMatch) { return contextMatch[1] } } @@ -28,19 +28,15 @@ function inferIndentation(line: string, contextLines: string[], previousIndent: } // Context matching edit strategy -export function applyContextMatching( - hunk: Hunk, - content: string[], - matchPosition: number, -): EditResult { - if (matchPosition === -1) { +export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult { + if (matchPosition === -1) { return { confidence: 0, result: content, strategy: "context" } } const newResult = [...content.slice(0, matchPosition)] let sourceIndex = matchPosition - for (const change of hunk.changes) { + for (const change of hunk.changes) { if (change.type === "context") { // Use the original line from content if available if (sourceIndex < content.length) { @@ -82,20 +78,16 @@ export function applyContextMatching( const confidence = validateEditResult(hunk, afterText) - return { + return { confidence, result: newResult, - strategy: "context" + strategy: "context", } } // DMP edit strategy -export function applyDMP( - hunk: Hunk, - content: string[], - matchPosition: number, -): EditResult { - if (matchPosition === -1) { +export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult { + if (matchPosition === -1) { return { confidence: 0, result: content, strategy: "dmp" } } @@ -105,9 +97,9 @@ export function applyDMP( const beforeLineCount = hunk.changes .filter((change) => change.type === "context" || change.type === "remove") .reduce((count, change) => count + change.content.split("\n").length, 0) - - // Build BEFORE block (context + removals) - const beforeLines = hunk.changes + + // Build BEFORE block (context + removals) + const beforeLines = hunk.changes .filter((change) => change.type === "context" || change.type === "remove") .map((change) => { if (change.originalLine) { @@ -115,9 +107,9 @@ export function applyDMP( } return change.indent ? change.indent + change.content : change.content }) - - // Build AFTER block (context + additions) - const afterLines = hunk.changes + + // Build AFTER block (context + additions) + const afterLines = hunk.changes .filter((change) => change.type === "context" || change.type === "add") .map((change) => { if (change.originalLine) { @@ -139,17 +131,17 @@ export function applyDMP( const patchedLines = patchedText.split("\n") // Construct final result - const newResult = [ - ...content.slice(0, matchPosition), - ...patchedLines, + const newResult = [ + ...content.slice(0, matchPosition), + ...patchedLines, ...content.slice(matchPosition + beforeLineCount), ] - + const confidence = validateEditResult(hunk, patchedText) - - return { + + return { confidence, - result: newResult, + result: newResult, strategy: "dmp", } } @@ -171,7 +163,7 @@ export async function applyGitFallback(hunk: Hunk, content: string[]): Promise change.type === "context" || change.type === "remove") .map((change) => change.originalLine || change.indent + change.content) - + const replaceLines = hunk.changes .filter((change) => change.type === "context" || change.type === "add") .map((change) => change.originalLine || change.indent + change.content) @@ -272,16 +264,16 @@ export async function applyGitFallback(hunk: Hunk, content: string[]): Promise { // Don't attempt regular edits if confidence is too low if (confidence < confidenceThreshold) { console.log( - `Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...` + `Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`, ) return applyGitFallback(hunk, content) } diff --git a/src/core/diff/strategies/new-unified/index.ts b/src/core/diff/strategies/new-unified/index.ts index b0eac64..7c19384 100644 --- a/src/core/diff/strategies/new-unified/index.ts +++ b/src/core/diff/strategies/new-unified/index.ts @@ -242,7 +242,7 @@ Your diff here originalContent: string, diffContent: string, startLine?: number, - endLine?: number + endLine?: number, ): Promise { const parsedDiff = this.parseUnifiedDiff(diffContent) const originalLines = originalContent.split("\n") @@ -280,7 +280,7 @@ Your diff here subHunkResult, subSearchResult.index, subSearchResult.confidence, - this.confidenceThreshold + this.confidenceThreshold, ) if (subEditResult.confidence >= this.confidenceThreshold) { subHunkResult = subEditResult.result @@ -302,12 +302,12 @@ Your diff here const contextRatio = contextLines / totalLines let errorMsg = `Failed to find a matching location in the file (${Math.floor( - confidence * 100 + confidence * 100, )}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n` errorMsg += "Debug Info:\n" errorMsg += `- Search Strategy Used: ${strategy}\n` errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor( - contextRatio * 100 + contextRatio * 100, )}%)\n` errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n` @@ -339,7 +339,7 @@ Your diff here } else { // Edit failure - likely due to content mismatch let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor( - editResult.confidence * 100 + editResult.confidence * 100, )}% confidence)\n\n` errorMsg += "Debug Info:\n" errorMsg += "- The location was found but the content didn't match exactly\n" diff --git a/src/core/diff/strategies/new-unified/search-strategies.ts b/src/core/diff/strategies/new-unified/search-strategies.ts index 7bee5ba..97fd499 100644 --- a/src/core/diff/strategies/new-unified/search-strategies.ts +++ b/src/core/diff/strategies/new-unified/search-strategies.ts @@ -69,26 +69,26 @@ export function getDMPSimilarity(original: string, modified: string): number { export function validateEditResult(hunk: Hunk, result: string): number { // Build the expected text from the hunk const expectedText = hunk.changes - .filter(change => change.type === "context" || change.type === "add") - .map(change => change.indent ? change.indent + change.content : change.content) - .join("\n"); + .filter((change) => change.type === "context" || change.type === "add") + .map((change) => (change.indent ? change.indent + change.content : change.content)) + .join("\n") // Calculate similarity between the result and expected text - const similarity = getDMPSimilarity(expectedText, result); + const similarity = getDMPSimilarity(expectedText, result) // If the result is unchanged from original, return low confidence const originalText = hunk.changes - .filter(change => change.type === "context" || change.type === "remove") - .map(change => change.indent ? change.indent + change.content : change.content) - .join("\n"); + .filter((change) => change.type === "context" || change.type === "remove") + .map((change) => (change.indent ? change.indent + change.content : change.content)) + .join("\n") - const originalSimilarity = getDMPSimilarity(originalText, result); + const originalSimilarity = getDMPSimilarity(originalText, result) if (originalSimilarity > 0.97 && similarity !== 1) { - return 0.8 * similarity; // Some confidence since we found the right location + return 0.8 * similarity // Some confidence since we found the right location } - + // For partial matches, scale the confidence but keep it high if we're close - return similarity; + return similarity } // Helper function to validate context lines against original content @@ -114,7 +114,7 @@ function validateContextLines(searchStr: string, content: string, confidenceThre function createOverlappingWindows( content: string[], searchSize: number, - overlapSize: number = DEFAULT_OVERLAP_SIZE + overlapSize: number = DEFAULT_OVERLAP_SIZE, ): { window: string[]; startIndex: number }[] { const windows: { window: string[]; startIndex: number }[] = [] @@ -140,7 +140,7 @@ function createOverlappingWindows( // Helper function to combine overlapping matches function combineOverlappingMatches( matches: (SearchResult & { windowIndex: number })[], - overlapSize: number = DEFAULT_OVERLAP_SIZE + overlapSize: number = DEFAULT_OVERLAP_SIZE, ): SearchResult[] { if (matches.length === 0) { return [] @@ -162,7 +162,7 @@ function combineOverlappingMatches( (m) => Math.abs(m.windowIndex - match.windowIndex) === 1 && Math.abs(m.index - match.index) <= overlapSize && - !usedIndices.has(m.windowIndex) + !usedIndices.has(m.windowIndex), ) if (overlapping.length > 0) { @@ -196,7 +196,7 @@ export function findExactMatch( searchStr: string, content: string[], startIndex: number = 0, - confidenceThreshold: number = 0.97 + confidenceThreshold: number = 0.97, ): SearchResult { const searchLines = searchStr.split("\n") const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length) @@ -210,7 +210,7 @@ export function findExactMatch( const matchedContent = windowData.window .slice( windowStr.slice(0, exactMatch).split("\n").length - 1, - windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length + windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length, ) .join("\n") @@ -236,7 +236,7 @@ export function findSimilarityMatch( searchStr: string, content: string[], startIndex: number = 0, - confidenceThreshold: number = 0.97 + confidenceThreshold: number = 0.97, ): SearchResult { const searchLines = searchStr.split("\n") let bestScore = 0 @@ -269,7 +269,7 @@ export function findLevenshteinMatch( searchStr: string, content: string[], startIndex: number = 0, - confidenceThreshold: number = 0.97 + confidenceThreshold: number = 0.97, ): SearchResult { const searchLines = searchStr.split("\n") const candidates = [] @@ -324,7 +324,7 @@ export function findAnchorMatch( searchStr: string, content: string[], startIndex: number = 0, - confidenceThreshold: number = 0.97 + confidenceThreshold: number = 0.97, ): SearchResult { const searchLines = searchStr.split("\n") const { first, last } = identifyAnchors(searchStr) @@ -391,7 +391,7 @@ export function findBestMatch( searchStr: string, content: string[], startIndex: number = 0, - confidenceThreshold: number = 0.97 + confidenceThreshold: number = 0.97, ): SearchResult { const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch] diff --git a/src/core/diff/strategies/new-unified/types.ts b/src/core/diff/strategies/new-unified/types.ts index a734f6e..0e243d3 100644 --- a/src/core/diff/strategies/new-unified/types.ts +++ b/src/core/diff/strategies/new-unified/types.ts @@ -1,20 +1,20 @@ export type Change = { - type: 'context' | 'add' | 'remove'; - content: string; - indent: string; - originalLine?: string; -}; + type: "context" | "add" | "remove" + content: string + indent: string + originalLine?: string +} export type Hunk = { - changes: Change[]; -}; + changes: Change[] +} export type Diff = { - hunks: Hunk[]; -}; + hunks: Hunk[] +} export type EditResult = { - confidence: number; - result: string[]; - strategy: string; -}; \ No newline at end of file + confidence: number + result: string[] + strategy: string +} diff --git a/src/core/diff/strategies/search-replace.ts b/src/core/diff/strategies/search-replace.ts index 2d2bcab..1ede3c3 100644 --- a/src/core/diff/strategies/search-replace.ts +++ b/src/core/diff/strategies/search-replace.ts @@ -1,72 +1,74 @@ import { DiffStrategy, DiffResult } from "../types" import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text" -const BUFFER_LINES = 20; // Number of extra context lines to show before and after matches +const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches function levenshteinDistance(a: string, b: string): number { - const matrix: number[][] = []; + const matrix: number[][] = [] - // Initialize matrix - for (let i = 0; i <= a.length; i++) { - matrix[i] = [i]; - } - for (let j = 0; j <= b.length; j++) { - matrix[0][j] = j; - } + // Initialize matrix + for (let i = 0; i <= a.length; i++) { + matrix[i] = [i] + } + for (let j = 0; j <= b.length; j++) { + matrix[0][j] = j + } - // Fill matrix - for (let i = 1; i <= a.length; i++) { - for (let j = 1; j <= b.length; j++) { - if (a[i-1] === b[j-1]) { - matrix[i][j] = matrix[i-1][j-1]; - } else { - matrix[i][j] = Math.min( - matrix[i-1][j-1] + 1, // substitution - matrix[i][j-1] + 1, // insertion - matrix[i-1][j] + 1 // deletion - ); - } - } - } + // Fill matrix + for (let i = 1; i <= a.length; i++) { + for (let j = 1; j <= b.length; j++) { + if (a[i - 1] === b[j - 1]) { + matrix[i][j] = matrix[i - 1][j - 1] + } else { + matrix[i][j] = Math.min( + matrix[i - 1][j - 1] + 1, // substitution + matrix[i][j - 1] + 1, // insertion + matrix[i - 1][j] + 1, // deletion + ) + } + } + } - return matrix[a.length][b.length]; + return matrix[a.length][b.length] } function getSimilarity(original: string, search: string): number { - if (search === '') { - return 1; - } + if (search === "") { + return 1 + } - // Normalize strings by removing extra whitespace but preserve case - const normalizeStr = (str: string) => str.replace(/\s+/g, ' ').trim(); - - const normalizedOriginal = normalizeStr(original); - const normalizedSearch = normalizeStr(search); - - if (normalizedOriginal === normalizedSearch) { return 1; } - - // Calculate Levenshtein distance - const distance = levenshteinDistance(normalizedOriginal, normalizedSearch); - - // Calculate similarity ratio (0 to 1, where 1 is exact match) - const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length); - return 1 - (distance / maxLength); + // Normalize strings by removing extra whitespace but preserve case + const normalizeStr = (str: string) => str.replace(/\s+/g, " ").trim() + + const normalizedOriginal = normalizeStr(original) + const normalizedSearch = normalizeStr(search) + + if (normalizedOriginal === normalizedSearch) { + return 1 + } + + // Calculate Levenshtein distance + const distance = levenshteinDistance(normalizedOriginal, normalizedSearch) + + // Calculate similarity ratio (0 to 1, where 1 is exact match) + const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length) + return 1 - distance / maxLength } export class SearchReplaceDiffStrategy implements DiffStrategy { - private fuzzyThreshold: number; - private bufferLines: number; + private fuzzyThreshold: number + private bufferLines: number - constructor(fuzzyThreshold?: number, bufferLines?: number) { - // Use provided threshold or default to exact matching (1.0) - // Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9) - // so we use it directly here - this.fuzzyThreshold = fuzzyThreshold ?? 1.0; - this.bufferLines = bufferLines ?? BUFFER_LINES; - } + constructor(fuzzyThreshold?: number, bufferLines?: number) { + // Use provided threshold or default to exact matching (1.0) + // Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9) + // so we use it directly here + this.fuzzyThreshold = fuzzyThreshold ?? 1.0 + this.bufferLines = bufferLines ?? BUFFER_LINES + } - getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string { - return `## apply_diff + getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string { + return `## apply_diff Description: Request to replace existing code using a search and replace block. This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with. The tool will maintain proper indentation and formatting while making changes. @@ -125,193 +127,204 @@ Your search/replace content here 1 5 ` - } + } - async applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise { - // Extract the search and replace blocks - const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/); - if (!match) { - return { - success: false, - error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers` - }; - } + async applyDiff( + originalContent: string, + diffContent: string, + startLine?: number, + endLine?: number, + ): Promise { + // Extract the search and replace blocks + const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/) + if (!match) { + return { + success: false, + error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`, + } + } - let [_, searchContent, replaceContent] = match; + let [_, searchContent, replaceContent] = match - // Detect line ending from original content - const lineEnding = originalContent.includes('\r\n') ? '\r\n' : '\n'; + // Detect line ending from original content + const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n" - // Strip line numbers from search and replace content if every line starts with a line number - if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) { - searchContent = stripLineNumbers(searchContent); - replaceContent = stripLineNumbers(replaceContent); - } - - // Split content into lines, handling both \n and \r\n - const searchLines = searchContent === '' ? [] : searchContent.split(/\r?\n/); - const replaceLines = replaceContent === '' ? [] : replaceContent.split(/\r?\n/); - const originalLines = originalContent.split(/\r?\n/); + // Strip line numbers from search and replace content if every line starts with a line number + if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) { + searchContent = stripLineNumbers(searchContent) + replaceContent = stripLineNumbers(replaceContent) + } - // Validate that empty search requires start line - if (searchLines.length === 0 && !startLine) { - return { - success: false, - error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted` - }; - } + // Split content into lines, handling both \n and \r\n + const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/) + const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/) + const originalLines = originalContent.split(/\r?\n/) - // Validate that empty search requires same start and end line - if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) { - return { - success: false, - error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line` - }; - } - - // Initialize search variables - let matchIndex = -1; - let bestMatchScore = 0; - let bestMatchContent = ""; - const searchChunk = searchLines.join('\n'); + // Validate that empty search requires start line + if (searchLines.length === 0 && !startLine) { + return { + success: false, + error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`, + } + } - // Determine search bounds - let searchStartIndex = 0; - let searchEndIndex = originalLines.length; + // Validate that empty search requires same start and end line + if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) { + return { + success: false, + error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`, + } + } - // Validate and handle line range if provided - if (startLine && endLine) { - // Convert to 0-based index - const exactStartIndex = startLine - 1; - const exactEndIndex = endLine - 1; + // Initialize search variables + let matchIndex = -1 + let bestMatchScore = 0 + let bestMatchContent = "" + const searchChunk = searchLines.join("\n") - if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) { - return { - success: false, - error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`, - }; - } + // Determine search bounds + let searchStartIndex = 0 + let searchEndIndex = originalLines.length - // Try exact match first - const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join('\n'); - const similarity = getSimilarity(originalChunk, searchChunk); - if (similarity >= this.fuzzyThreshold) { - matchIndex = exactStartIndex; - bestMatchScore = similarity; - bestMatchContent = originalChunk; - } else { - // Set bounds for buffered search - searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1)); - searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines); - } - } + // Validate and handle line range if provided + if (startLine && endLine) { + // Convert to 0-based index + const exactStartIndex = startLine - 1 + const exactEndIndex = endLine - 1 - // If no match found yet, try middle-out search within bounds - if (matchIndex === -1) { - const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2); - let leftIndex = midPoint; - let rightIndex = midPoint + 1; + if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) { + return { + success: false, + error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`, + } + } - // Search outward from the middle within bounds - while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) { - // Check left side if still in range - if (leftIndex >= searchStartIndex) { - const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join('\n'); - const similarity = getSimilarity(originalChunk, searchChunk); - if (similarity > bestMatchScore) { - bestMatchScore = similarity; - matchIndex = leftIndex; - bestMatchContent = originalChunk; - } - leftIndex--; - } + // Try exact match first + const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n") + const similarity = getSimilarity(originalChunk, searchChunk) + if (similarity >= this.fuzzyThreshold) { + matchIndex = exactStartIndex + bestMatchScore = similarity + bestMatchContent = originalChunk + } else { + // Set bounds for buffered search + searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1)) + searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines) + } + } - // Check right side if still in range - if (rightIndex <= searchEndIndex - searchLines.length) { - const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join('\n'); - const similarity = getSimilarity(originalChunk, searchChunk); - if (similarity > bestMatchScore) { - bestMatchScore = similarity; - matchIndex = rightIndex; - bestMatchContent = originalChunk; - } - rightIndex++; - } - } - } + // If no match found yet, try middle-out search within bounds + if (matchIndex === -1) { + const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2) + let leftIndex = midPoint + let rightIndex = midPoint + 1 - // Require similarity to meet threshold - if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) { - const searchChunk = searchLines.join('\n'); - const originalContentSection = startLine !== undefined && endLine !== undefined - ? `\n\nOriginal Content:\n${addLineNumbers( - originalLines.slice( - Math.max(0, startLine - 1 - this.bufferLines), - Math.min(originalLines.length, endLine + this.bufferLines) - ).join('\n'), - Math.max(1, startLine - this.bufferLines) - )}` - : `\n\nOriginal Content:\n${addLineNumbers(originalLines.join('\n'))}`; + // Search outward from the middle within bounds + while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) { + // Check left side if still in range + if (leftIndex >= searchStartIndex) { + const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n") + const similarity = getSimilarity(originalChunk, searchChunk) + if (similarity > bestMatchScore) { + bestMatchScore = similarity + matchIndex = leftIndex + bestMatchContent = originalChunk + } + leftIndex-- + } - const bestMatchSection = bestMatchContent - ? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}` - : `\n\nBest Match Found:\n(no match)`; + // Check right side if still in range + if (rightIndex <= searchEndIndex - searchLines.length) { + const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n") + const similarity = getSimilarity(originalChunk, searchChunk) + if (similarity > bestMatchScore) { + bestMatchScore = similarity + matchIndex = rightIndex + bestMatchContent = originalChunk + } + rightIndex++ + } + } + } - const lineRange = startLine || endLine ? - ` at ${startLine ? `start: ${startLine}` : 'start'} to ${endLine ? `end: ${endLine}` : 'end'}` : ''; - return { - success: false, - error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : 'start to end'}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}` - }; - } + // Require similarity to meet threshold + if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) { + const searchChunk = searchLines.join("\n") + const originalContentSection = + startLine !== undefined && endLine !== undefined + ? `\n\nOriginal Content:\n${addLineNumbers( + originalLines + .slice( + Math.max(0, startLine - 1 - this.bufferLines), + Math.min(originalLines.length, endLine + this.bufferLines), + ) + .join("\n"), + Math.max(1, startLine - this.bufferLines), + )}` + : `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}` - // Get the matched lines from the original content - const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length); - - // Get the exact indentation (preserving tabs/spaces) of each line - const originalIndents = matchedLines.map(line => { - const match = line.match(/^[\t ]*/); - return match ? match[0] : ''; - }); + const bestMatchSection = bestMatchContent + ? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}` + : `\n\nBest Match Found:\n(no match)` - // Get the exact indentation of each line in the search block - const searchIndents = searchLines.map(line => { - const match = line.match(/^[\t ]*/); - return match ? match[0] : ''; - }); + const lineRange = + startLine || endLine + ? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}` + : "" + return { + success: false, + error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`, + } + } - // Apply the replacement while preserving exact indentation - const indentedReplaceLines = replaceLines.map((line, i) => { - // Get the matched line's exact indentation - const matchedIndent = originalIndents[0] || ''; - - // Get the current line's indentation relative to the search content - const currentIndentMatch = line.match(/^[\t ]*/); - const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ''; - const searchBaseIndent = searchIndents[0] || ''; - - // Calculate the relative indentation level - const searchBaseLevel = searchBaseIndent.length; - const currentLevel = currentIndent.length; - const relativeLevel = currentLevel - searchBaseLevel; - - // If relative level is negative, remove indentation from matched indent - // If positive, add to matched indent - const finalIndent = relativeLevel < 0 - ? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel)) - : matchedIndent + currentIndent.slice(searchBaseLevel); - - return finalIndent + line.trim(); - }); + // Get the matched lines from the original content + const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length) - // Construct the final content - const beforeMatch = originalLines.slice(0, matchIndex); - const afterMatch = originalLines.slice(matchIndex + searchLines.length); - - const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding); - return { - success: true, - content: finalContent - }; - } -} \ No newline at end of file + // Get the exact indentation (preserving tabs/spaces) of each line + const originalIndents = matchedLines.map((line) => { + const match = line.match(/^[\t ]*/) + return match ? match[0] : "" + }) + + // Get the exact indentation of each line in the search block + const searchIndents = searchLines.map((line) => { + const match = line.match(/^[\t ]*/) + return match ? match[0] : "" + }) + + // Apply the replacement while preserving exact indentation + const indentedReplaceLines = replaceLines.map((line, i) => { + // Get the matched line's exact indentation + const matchedIndent = originalIndents[0] || "" + + // Get the current line's indentation relative to the search content + const currentIndentMatch = line.match(/^[\t ]*/) + const currentIndent = currentIndentMatch ? currentIndentMatch[0] : "" + const searchBaseIndent = searchIndents[0] || "" + + // Calculate the relative indentation level + const searchBaseLevel = searchBaseIndent.length + const currentLevel = currentIndent.length + const relativeLevel = currentLevel - searchBaseLevel + + // If relative level is negative, remove indentation from matched indent + // If positive, add to matched indent + const finalIndent = + relativeLevel < 0 + ? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel)) + : matchedIndent + currentIndent.slice(searchBaseLevel) + + return finalIndent + line.trim() + }) + + // Construct the final content + const beforeMatch = originalLines.slice(0, matchIndex) + const afterMatch = originalLines.slice(matchIndex + searchLines.length) + + const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding) + return { + success: true, + content: finalContent, + } + } +} diff --git a/src/core/diff/strategies/unified.ts b/src/core/diff/strategies/unified.ts index 564398f..f1cdb3b 100644 --- a/src/core/diff/strategies/unified.ts +++ b/src/core/diff/strategies/unified.ts @@ -2,8 +2,8 @@ import { applyPatch } from "diff" import { DiffStrategy, DiffResult } from "../types" export class UnifiedDiffStrategy implements DiffStrategy { - getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string { - return `## apply_diff + getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string { + return `## apply_diff Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3). Parameters: @@ -106,32 +106,32 @@ Usage: Your diff here ` - } + } - async applyDiff(originalContent: string, diffContent: string): Promise { - try { - const result = applyPatch(originalContent, diffContent) - if (result === false) { - return { - success: false, - error: "Failed to apply unified diff - patch rejected", - details: { - searchContent: diffContent - } - } - } - return { - success: true, - content: result - } - } catch (error) { - return { - success: false, - error: `Error applying unified diff: ${error.message}`, - details: { - searchContent: diffContent - } - } - } - } + async applyDiff(originalContent: string, diffContent: string): Promise { + try { + const result = applyPatch(originalContent, diffContent) + if (result === false) { + return { + success: false, + error: "Failed to apply unified diff - patch rejected", + details: { + searchContent: diffContent, + }, + } + } + return { + success: true, + content: result, + } + } catch (error) { + return { + success: false, + error: `Error applying unified diff: ${error.message}`, + details: { + searchContent: diffContent, + }, + } + } + } } diff --git a/src/core/diff/types.ts b/src/core/diff/types.ts index 518112a..61275de 100644 --- a/src/core/diff/types.ts +++ b/src/core/diff/types.ts @@ -2,31 +2,35 @@ * Interface for implementing different diff strategies */ -export type DiffResult = - | { success: true; content: string } - | { success: false; error: string; details?: { - similarity?: number; - threshold?: number; - matchedRange?: { start: number; end: number }; - searchContent?: string; - bestMatch?: string; - }}; +export type DiffResult = + | { success: true; content: string } + | { + success: false + error: string + details?: { + similarity?: number + threshold?: number + matchedRange?: { start: number; end: number } + searchContent?: string + bestMatch?: string + } + } export interface DiffStrategy { - /** - * Get the tool description for this diff strategy - * @param args The tool arguments including cwd and toolOptions - * @returns The complete tool description including format requirements and examples - */ - getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string + /** + * Get the tool description for this diff strategy + * @param args The tool arguments including cwd and toolOptions + * @returns The complete tool description including format requirements and examples + */ + getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string - /** - * Apply a diff to the original content - * @param originalContent The original file content - * @param diffContent The diff content in the strategy's format - * @param startLine Optional line number where the search block starts. If not provided, searches the entire file. - * @param endLine Optional line number where the search block ends. If not provided, searches the entire file. - * @returns A DiffResult object containing either the successful result or error details - */ - applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise -} \ No newline at end of file + /** + * Apply a diff to the original content + * @param originalContent The original file content + * @param diffContent The diff content in the strategy's format + * @param startLine Optional line number where the search block starts. If not provided, searches the entire file. + * @param endLine Optional line number where the search block ends. If not provided, searches the entire file. + * @returns A DiffResult object containing either the successful result or error details + */ + applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise +} diff --git a/src/core/mentions/__tests__/index.test.ts b/src/core/mentions/__tests__/index.test.ts index 609f0cf..7a779d3 100644 --- a/src/core/mentions/__tests__/index.test.ts +++ b/src/core/mentions/__tests__/index.test.ts @@ -1,20 +1,20 @@ // Create mock vscode module before importing anything const createMockUri = (scheme: string, path: string) => ({ scheme, - authority: '', + authority: "", path, - query: '', - fragment: '', + query: "", + fragment: "", fsPath: path, with: jest.fn(), toString: () => path, toJSON: () => ({ scheme, - authority: '', + authority: "", path, - query: '', - fragment: '' - }) + query: "", + fragment: "", + }), }) const mockExecuteCommand = jest.fn() @@ -23,9 +23,11 @@ const mockShowErrorMessage = jest.fn() const mockVscode = { workspace: { - workspaceFolders: [{ - uri: { fsPath: "/test/workspace" } - }] + workspaceFolders: [ + { + uri: { fsPath: "/test/workspace" }, + }, + ], }, window: { showErrorMessage: mockShowErrorMessage, @@ -34,17 +36,17 @@ const mockVscode = { createTextEditorDecorationType: jest.fn(), createOutputChannel: jest.fn(), createWebviewPanel: jest.fn(), - activeTextEditor: undefined + activeTextEditor: undefined, }, commands: { - executeCommand: mockExecuteCommand + executeCommand: mockExecuteCommand, }, env: { - openExternal: mockOpenExternal + openExternal: mockOpenExternal, }, Uri: { - parse: jest.fn((url: string) => createMockUri('https', url)), - file: jest.fn((path: string) => createMockUri('file', path)) + parse: jest.fn((url: string) => createMockUri("https", url)), + file: jest.fn((path: string) => createMockUri("file", path)), }, Position: jest.fn(), Range: jest.fn(), @@ -54,12 +56,12 @@ const mockVscode = { Error: 0, Warning: 1, Information: 2, - Hint: 3 - } + Hint: 3, + }, } // Mock modules -jest.mock('vscode', () => mockVscode) +jest.mock("vscode", () => mockVscode) jest.mock("../../../services/browser/UrlContentFetcher") jest.mock("../../../utils/git") @@ -74,7 +76,7 @@ describe("mentions", () => { beforeEach(() => { jest.clearAllMocks() - + // Create a mock instance with just the methods we need mockUrlContentFetcher = { launchBrowser: jest.fn().mockResolvedValue(undefined), @@ -94,14 +96,10 @@ Date: Mon Jan 5 23:50:06 2025 -0500 Detailed commit message with multiple lines - Fixed parsing issue - Added tests` - + jest.mocked(git.getCommitInfo).mockResolvedValue(commitInfo) - const result = await parseMentions( - `Check out this commit @${commitHash}`, - mockCwd, - mockUrlContentFetcher - ) + const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher) expect(result).toContain(`'${commitHash}' (see below for commit info)`) expect(result).toContain(``) @@ -111,14 +109,10 @@ Detailed commit message with multiple lines it("should handle errors fetching git info", async () => { const commitHash = "abc1234" const errorMessage = "Failed to get commit info" - + jest.mocked(git.getCommitInfo).mockRejectedValue(new Error(errorMessage)) - const result = await parseMentions( - `Check out this commit @${commitHash}`, - mockCwd, - mockUrlContentFetcher - ) + const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher) expect(result).toContain(`'${commitHash}' (see below for commit info)`) expect(result).toContain(``) @@ -143,13 +137,15 @@ Detailed commit message with multiple lines const mockUri = mockVscode.Uri.parse(url) expect(mockOpenExternal).toHaveBeenCalled() const calledArg = mockOpenExternal.mock.calls[0][0] - expect(calledArg).toEqual(expect.objectContaining({ - scheme: mockUri.scheme, - authority: mockUri.authority, - path: mockUri.path, - query: mockUri.query, - fragment: mockUri.fragment - })) + expect(calledArg).toEqual( + expect.objectContaining({ + scheme: mockUri.scheme, + authority: mockUri.authority, + path: mockUri.path, + query: mockUri.query, + fragment: mockUri.fragment, + }), + ) }) }) -}) \ No newline at end of file +}) diff --git a/src/core/mode-validator.ts b/src/core/mode-validator.ts index 00523fb..c432c73 100644 --- a/src/core/mode-validator.ts +++ b/src/core/mode-validator.ts @@ -1,12 +1,10 @@ -import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from '../shared/modes'; +import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from "../shared/modes" -export { isToolAllowedForMode }; -export type { TestToolName }; +export { isToolAllowedForMode } +export type { TestToolName } export function validateToolUse(toolName: TestToolName, mode: Mode): void { - if (!isToolAllowedForMode(toolName, mode)) { - throw new Error( - `Tool "${toolName}" is not allowed in ${mode} mode.` - ); - } -} \ No newline at end of file + if (!isToolAllowedForMode(toolName, mode)) { + throw new Error(`Tool "${toolName}" is not allowed in ${mode} mode.`) + } +} diff --git a/src/core/prompts/__tests__/system.test.ts b/src/core/prompts/__tests__/system.test.ts index f12a101..de9c503 100644 --- a/src/core/prompts/__tests__/system.test.ts +++ b/src/core/prompts/__tests__/system.test.ts @@ -1,422 +1,357 @@ -import { SYSTEM_PROMPT, addCustomInstructions } from '../system' -import { McpHub } from '../../../services/mcp/McpHub' -import { McpServer } from '../../../shared/mcp' -import { ClineProvider } from '../../../core/webview/ClineProvider' -import { SearchReplaceDiffStrategy } from '../../../core/diff/strategies/search-replace' -import fs from 'fs/promises' -import os from 'os' -import { defaultModeSlug, modes } from '../../../shared/modes' +import { SYSTEM_PROMPT, addCustomInstructions } from "../system" +import { McpHub } from "../../../services/mcp/McpHub" +import { McpServer } from "../../../shared/mcp" +import { ClineProvider } from "../../../core/webview/ClineProvider" +import { SearchReplaceDiffStrategy } from "../../../core/diff/strategies/search-replace" +import fs from "fs/promises" +import os from "os" +import { defaultModeSlug, modes } from "../../../shared/modes" // Import path utils to get access to toPosix string extension -import '../../../utils/path' +import "../../../utils/path" // Mock environment-specific values for consistent tests -jest.mock('os', () => ({ - ...jest.requireActual('os'), - homedir: () => '/home/user' +jest.mock("os", () => ({ + ...jest.requireActual("os"), + homedir: () => "/home/user", })) -jest.mock('default-shell', () => '/bin/bash') +jest.mock("default-shell", () => "/bin/bash") -jest.mock('os-name', () => () => 'Linux') +jest.mock("os-name", () => () => "Linux") // Mock fs.readFile to return empty mcpServers config and mock rules files -jest.mock('fs/promises', () => ({ - ...jest.requireActual('fs/promises'), - readFile: jest.fn().mockImplementation(async (path: string) => { - if (path.endsWith('mcpSettings.json')) { - return '{"mcpServers": {}}' - } - if (path.endsWith('.clinerules-code')) { - return '# Code Mode Rules\n1. Code specific rule' - } - if (path.endsWith('.clinerules-ask')) { - return '# Ask Mode Rules\n1. Ask specific rule' - } - if (path.endsWith('.clinerules-architect')) { - return '# Architect Mode Rules\n1. Architect specific rule' - } - if (path.endsWith('.clinerules')) { - return '# Test Rules\n1. First rule\n2. Second rule' - } - return '' - }), - writeFile: jest.fn().mockResolvedValue(undefined) +jest.mock("fs/promises", () => ({ + ...jest.requireActual("fs/promises"), + readFile: jest.fn().mockImplementation(async (path: string) => { + if (path.endsWith("mcpSettings.json")) { + return '{"mcpServers": {}}' + } + if (path.endsWith(".clinerules-code")) { + return "# Code Mode Rules\n1. Code specific rule" + } + if (path.endsWith(".clinerules-ask")) { + return "# Ask Mode Rules\n1. Ask specific rule" + } + if (path.endsWith(".clinerules-architect")) { + return "# Architect Mode Rules\n1. Architect specific rule" + } + if (path.endsWith(".clinerules")) { + return "# Test Rules\n1. First rule\n2. Second rule" + } + return "" + }), + writeFile: jest.fn().mockResolvedValue(undefined), })) // Create a minimal mock of ClineProvider const mockProvider = { - ensureMcpServersDirectoryExists: async () => '/mock/mcp/path', - ensureSettingsDirectoryExists: async () => '/mock/settings/path', - postMessageToWebview: async () => {}, - context: { - extension: { - packageJSON: { - version: '1.0.0' - } - } - } + ensureMcpServersDirectoryExists: async () => "/mock/mcp/path", + ensureSettingsDirectoryExists: async () => "/mock/settings/path", + postMessageToWebview: async () => {}, + context: { + extension: { + packageJSON: { + version: "1.0.0", + }, + }, + }, } as unknown as ClineProvider // Instead of extending McpHub, create a mock that implements just what we need -const createMockMcpHub = (): McpHub => ({ - getServers: () => [], - getMcpServersPath: async () => '/mock/mcp/path', - getMcpSettingsFilePath: async () => '/mock/settings/path', - dispose: async () => {}, - // Add other required public methods with no-op implementations - restartConnection: async () => {}, - readResource: async () => ({ contents: [] }), - callTool: async () => ({ content: [] }), - toggleServerDisabled: async () => {}, - toggleToolAlwaysAllow: async () => {}, - isConnecting: false, - connections: [] -} as unknown as McpHub) +const createMockMcpHub = (): McpHub => + ({ + getServers: () => [], + getMcpServersPath: async () => "/mock/mcp/path", + getMcpSettingsFilePath: async () => "/mock/settings/path", + dispose: async () => {}, + // Add other required public methods with no-op implementations + restartConnection: async () => {}, + readResource: async () => ({ contents: [] }), + callTool: async () => ({ content: [] }), + toggleServerDisabled: async () => {}, + toggleToolAlwaysAllow: async () => {}, + isConnecting: false, + connections: [], + }) as unknown as McpHub -describe('SYSTEM_PROMPT', () => { - let mockMcpHub: McpHub +describe("SYSTEM_PROMPT", () => { + let mockMcpHub: McpHub - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => { + jest.clearAllMocks() + }) - afterEach(async () => { - // Clean up any McpHub instances - if (mockMcpHub) { - await mockMcpHub.dispose() - } - }) + afterEach(async () => { + // Clean up any McpHub instances + if (mockMcpHub) { + await mockMcpHub.dispose() + } + }) - it('should maintain consistent system prompt', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, // supportsComputerUse - undefined, // mcpHub - undefined, // diffStrategy - undefined // browserViewportSize - ) - - expect(prompt).toMatchSnapshot() - }) + it("should maintain consistent system prompt", async () => { + const prompt = await SYSTEM_PROMPT( + "/test/path", + false, // supportsComputerUse + undefined, // mcpHub + undefined, // diffStrategy + undefined, // browserViewportSize + ) - it('should include browser actions when supportsComputerUse is true', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - true, - undefined, - undefined, - '1280x800' - ) - - expect(prompt).toMatchSnapshot() - }) + expect(prompt).toMatchSnapshot() + }) - it('should include MCP server info when mcpHub is provided', async () => { - mockMcpHub = createMockMcpHub() + it("should include browser actions when supportsComputerUse is true", async () => { + const prompt = await SYSTEM_PROMPT("/test/path", true, undefined, undefined, "1280x800") - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - mockMcpHub - ) - - expect(prompt).toMatchSnapshot() - }) + expect(prompt).toMatchSnapshot() + }) - it('should explicitly handle undefined mcpHub', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - undefined, // explicitly undefined mcpHub - undefined, - undefined - ) - - expect(prompt).toMatchSnapshot() - }) + it("should include MCP server info when mcpHub is provided", async () => { + mockMcpHub = createMockMcpHub() - it('should handle different browser viewport sizes', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - true, - undefined, - undefined, - '900x600' // different viewport size - ) - - expect(prompt).toMatchSnapshot() - }) + const prompt = await SYSTEM_PROMPT("/test/path", false, mockMcpHub) - it('should include diff strategy tool description', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - undefined, - new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase - undefined - ) - - expect(prompt).toMatchSnapshot() - }) + expect(prompt).toMatchSnapshot() + }) - afterAll(() => { - jest.restoreAllMocks() - }) + it("should explicitly handle undefined mcpHub", async () => { + const prompt = await SYSTEM_PROMPT( + "/test/path", + false, + undefined, // explicitly undefined mcpHub + undefined, + undefined, + ) + + expect(prompt).toMatchSnapshot() + }) + + it("should handle different browser viewport sizes", async () => { + const prompt = await SYSTEM_PROMPT( + "/test/path", + true, + undefined, + undefined, + "900x600", // different viewport size + ) + + expect(prompt).toMatchSnapshot() + }) + + it("should include diff strategy tool description", async () => { + const prompt = await SYSTEM_PROMPT( + "/test/path", + false, + undefined, + new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase + undefined, + ) + + expect(prompt).toMatchSnapshot() + }) + + afterAll(() => { + jest.restoreAllMocks() + }) }) -describe('addCustomInstructions', () => { - beforeEach(() => { - jest.clearAllMocks() - }) +describe("addCustomInstructions", () => { + beforeEach(() => { + jest.clearAllMocks() + }) - it('should generate correct prompt for architect mode', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - undefined, - undefined, - undefined, - 'architect' - ) - - expect(prompt).toMatchSnapshot() - }) + it("should generate correct prompt for architect mode", async () => { + const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "architect") - it('should generate correct prompt for ask mode', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - undefined, - undefined, - undefined, - 'ask' - ) - - expect(prompt).toMatchSnapshot() - }) + expect(prompt).toMatchSnapshot() + }) - it('should prioritize mode-specific rules for code mode', async () => { - const instructions = await addCustomInstructions( - {}, - '/test/path', - defaultModeSlug - ) - expect(instructions).toMatchSnapshot() - }) + it("should generate correct prompt for ask mode", async () => { + const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "ask") - it('should prioritize mode-specific rules for ask mode', async () => { - const instructions = await addCustomInstructions( - {}, - '/test/path', - modes[2].slug - ) - expect(instructions).toMatchSnapshot() - }) + expect(prompt).toMatchSnapshot() + }) - it('should prioritize mode-specific rules for architect mode', async () => { - const instructions = await addCustomInstructions( - {}, - '/test/path', - modes[1].slug - ) - - expect(instructions).toMatchSnapshot() - }) + it("should prioritize mode-specific rules for code mode", async () => { + const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug) + expect(instructions).toMatchSnapshot() + }) - it('should prioritize mode-specific rules for test engineer mode', async () => { - // Mock readFile to include test engineer rules - const mockReadFile = jest.fn().mockImplementation(async (path: string) => { - if (path.endsWith('.clinerules-test')) { - return '# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code' - } - if (path.endsWith('.clinerules')) { - return '# Test Rules\n1. First rule\n2. Second rule' - } - return '' - }) - jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile) + it("should prioritize mode-specific rules for ask mode", async () => { + const instructions = await addCustomInstructions({}, "/test/path", modes[2].slug) + expect(instructions).toMatchSnapshot() + }) - const instructions = await addCustomInstructions( - {}, - '/test/path', - 'test' - ) - expect(instructions).toMatchSnapshot() - }) + it("should prioritize mode-specific rules for architect mode", async () => { + const instructions = await addCustomInstructions({}, "/test/path", modes[1].slug) - it('should prioritize mode-specific rules for code reviewer mode', async () => { - // Mock readFile to include code reviewer rules - const mockReadFile = jest.fn().mockImplementation(async (path: string) => { - if (path.endsWith('.clinerules-review')) { - return '# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices' - } - if (path.endsWith('.clinerules')) { - return '# Test Rules\n1. First rule\n2. Second rule' - } - return '' - }) - jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile) + expect(instructions).toMatchSnapshot() + }) - const instructions = await addCustomInstructions( - {}, - '/test/path', - 'review' - ) - expect(instructions).toMatchSnapshot() - }) + it("should prioritize mode-specific rules for test engineer mode", async () => { + // Mock readFile to include test engineer rules + const mockReadFile = jest.fn().mockImplementation(async (path: string) => { + if (path.endsWith(".clinerules-test")) { + return "# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code" + } + if (path.endsWith(".clinerules")) { + return "# Test Rules\n1. First rule\n2. Second rule" + } + return "" + }) + jest.spyOn(fs, "readFile").mockImplementation(mockReadFile) - it('should generate correct prompt for test engineer mode', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - undefined, - undefined, - undefined, - 'test' - ) - - // Verify test engineer role requirements - expect(prompt).toContain('must ask the user to confirm before making ANY changes to non-test code') - expect(prompt).toContain('ask the user to confirm your test plan') - expect(prompt).toMatchSnapshot() - }) + const instructions = await addCustomInstructions({}, "/test/path", "test") + expect(instructions).toMatchSnapshot() + }) - it('should generate correct prompt for code reviewer mode', async () => { - const prompt = await SYSTEM_PROMPT( - '/test/path', - false, - undefined, - undefined, - undefined, - 'review' - ) - - // Verify code reviewer role constraints - expect(prompt).toContain('providing detailed, actionable feedback') - expect(prompt).toContain('maintain a read-only approach') - expect(prompt).toMatchSnapshot() - }) + it("should prioritize mode-specific rules for code reviewer mode", async () => { + // Mock readFile to include code reviewer rules + const mockReadFile = jest.fn().mockImplementation(async (path: string) => { + if (path.endsWith(".clinerules-review")) { + return "# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices" + } + if (path.endsWith(".clinerules")) { + return "# Test Rules\n1. First rule\n2. Second rule" + } + return "" + }) + jest.spyOn(fs, "readFile").mockImplementation(mockReadFile) - it('should fall back to generic rules when mode-specific rules not found', async () => { - // Mock readFile to return ENOENT for mode-specific file - const mockReadFile = jest.fn().mockImplementation(async (path: string) => { - if (path.endsWith('.clinerules-code') || - path.endsWith('.clinerules-test') || - path.endsWith('.clinerules-review')) { - const error = new Error('ENOENT') as NodeJS.ErrnoException - error.code = 'ENOENT' - throw error - } - if (path.endsWith('.clinerules')) { - return '# Test Rules\n1. First rule\n2. Second rule' - } - return '' - }) - jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile) + const instructions = await addCustomInstructions({}, "/test/path", "review") + expect(instructions).toMatchSnapshot() + }) - const instructions = await addCustomInstructions( - {}, - '/test/path', - defaultModeSlug - ) - - expect(instructions).toMatchSnapshot() - }) + it("should generate correct prompt for test engineer mode", async () => { + const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "test") - it('should include preferred language when provided', async () => { - const instructions = await addCustomInstructions( - { preferredLanguage: 'Spanish' }, - '/test/path', - defaultModeSlug - ) - - expect(instructions).toMatchSnapshot() - }) + // Verify test engineer role requirements + expect(prompt).toContain("must ask the user to confirm before making ANY changes to non-test code") + expect(prompt).toContain("ask the user to confirm your test plan") + expect(prompt).toMatchSnapshot() + }) - it('should include custom instructions when provided', async () => { - const instructions = await addCustomInstructions( - { customInstructions: 'Custom test instructions' }, - '/test/path' - ) - - expect(instructions).toMatchSnapshot() - }) + it("should generate correct prompt for code reviewer mode", async () => { + const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "review") - it('should combine all custom instructions', async () => { - const instructions = await addCustomInstructions( - { - customInstructions: 'Custom test instructions', - preferredLanguage: 'French' - }, - '/test/path', - defaultModeSlug - ) - expect(instructions).toMatchSnapshot() - }) + // Verify code reviewer role constraints + expect(prompt).toContain("providing detailed, actionable feedback") + expect(prompt).toContain("maintain a read-only approach") + expect(prompt).toMatchSnapshot() + }) - it('should handle undefined mode-specific instructions', async () => { - const instructions = await addCustomInstructions( - {}, - '/test/path' - ) - - expect(instructions).toMatchSnapshot() - }) + it("should fall back to generic rules when mode-specific rules not found", async () => { + // Mock readFile to return ENOENT for mode-specific file + const mockReadFile = jest.fn().mockImplementation(async (path: string) => { + if ( + path.endsWith(".clinerules-code") || + path.endsWith(".clinerules-test") || + path.endsWith(".clinerules-review") + ) { + const error = new Error("ENOENT") as NodeJS.ErrnoException + error.code = "ENOENT" + throw error + } + if (path.endsWith(".clinerules")) { + return "# Test Rules\n1. First rule\n2. Second rule" + } + return "" + }) + jest.spyOn(fs, "readFile").mockImplementation(mockReadFile) - it('should trim mode-specific instructions', async () => { - const instructions = await addCustomInstructions( - { customInstructions: ' Custom mode instructions ' }, - '/test/path' - ) - - expect(instructions).toMatchSnapshot() - }) + const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug) - it('should handle empty mode-specific instructions', async () => { - const instructions = await addCustomInstructions( - { customInstructions: '' }, - '/test/path' - ) - - expect(instructions).toMatchSnapshot() - }) + expect(instructions).toMatchSnapshot() + }) - it('should combine global and mode-specific instructions', async () => { - const instructions = await addCustomInstructions( - { - customInstructions: 'Global instructions', - customPrompts: { - code: { customInstructions: 'Mode-specific instructions' } - } - }, - '/test/path', - defaultModeSlug - ) - - expect(instructions).toMatchSnapshot() - }) + it("should include preferred language when provided", async () => { + const instructions = await addCustomInstructions( + { preferredLanguage: "Spanish" }, + "/test/path", + defaultModeSlug, + ) - it('should prioritize mode-specific instructions after global ones', async () => { - const instructions = await addCustomInstructions( - { - customInstructions: 'First instruction', - customPrompts: { - code: { customInstructions: 'Second instruction' } - } - }, - '/test/path', - defaultModeSlug - ) - - const instructionParts = instructions.split('\n\n') - const globalIndex = instructionParts.findIndex(part => part.includes('First instruction')) - const modeSpecificIndex = instructionParts.findIndex(part => part.includes('Second instruction')) - - expect(globalIndex).toBeLessThan(modeSpecificIndex) - expect(instructions).toMatchSnapshot() - }) + expect(instructions).toMatchSnapshot() + }) - afterAll(() => { - jest.restoreAllMocks() - }) + it("should include custom instructions when provided", async () => { + const instructions = await addCustomInstructions( + { customInstructions: "Custom test instructions" }, + "/test/path", + ) + + expect(instructions).toMatchSnapshot() + }) + + it("should combine all custom instructions", async () => { + const instructions = await addCustomInstructions( + { + customInstructions: "Custom test instructions", + preferredLanguage: "French", + }, + "/test/path", + defaultModeSlug, + ) + expect(instructions).toMatchSnapshot() + }) + + it("should handle undefined mode-specific instructions", async () => { + const instructions = await addCustomInstructions({}, "/test/path") + + expect(instructions).toMatchSnapshot() + }) + + it("should trim mode-specific instructions", async () => { + const instructions = await addCustomInstructions( + { customInstructions: " Custom mode instructions " }, + "/test/path", + ) + + expect(instructions).toMatchSnapshot() + }) + + it("should handle empty mode-specific instructions", async () => { + const instructions = await addCustomInstructions({ customInstructions: "" }, "/test/path") + + expect(instructions).toMatchSnapshot() + }) + + it("should combine global and mode-specific instructions", async () => { + const instructions = await addCustomInstructions( + { + customInstructions: "Global instructions", + customPrompts: { + code: { customInstructions: "Mode-specific instructions" }, + }, + }, + "/test/path", + defaultModeSlug, + ) + + expect(instructions).toMatchSnapshot() + }) + + it("should prioritize mode-specific instructions after global ones", async () => { + const instructions = await addCustomInstructions( + { + customInstructions: "First instruction", + customPrompts: { + code: { customInstructions: "Second instruction" }, + }, + }, + "/test/path", + defaultModeSlug, + ) + + const instructionParts = instructions.split("\n\n") + const globalIndex = instructionParts.findIndex((part) => part.includes("First instruction")) + const modeSpecificIndex = instructionParts.findIndex((part) => part.includes("Second instruction")) + + expect(globalIndex).toBeLessThan(modeSpecificIndex) + expect(instructions).toMatchSnapshot() + }) + + afterAll(() => { + jest.restoreAllMocks() + }) }) diff --git a/src/core/prompts/sections/capabilities.ts b/src/core/prompts/sections/capabilities.ts index 98f3b14..c30e38a 100644 --- a/src/core/prompts/sections/capabilities.ts +++ b/src/core/prompts/sections/capabilities.ts @@ -2,27 +2,31 @@ import { DiffStrategy } from "../../diff/DiffStrategy" import { McpHub } from "../../../services/mcp/McpHub" export function getCapabilitiesSection( - cwd: string, - supportsComputerUse: boolean, - mcpHub?: McpHub, - diffStrategy?: DiffStrategy, + cwd: string, + supportsComputerUse: boolean, + mcpHub?: McpHub, + diffStrategy?: DiffStrategy, ): string { - return `==== + return `==== CAPABILITIES - You have access to tools that let you execute CLI commands on the user's computer, list files, view source code definitions, regex search${ - supportsComputerUse ? ", use the browser" : "" -}, read and write files, and ask follow-up questions. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, performing system operations, and much more. + supportsComputerUse ? ", use the browser" : "" + }, read and write files, and ask follow-up questions. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, performing system operations, and much more. - When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('${cwd}') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop. - You can use search_files to perform regex searches across files in a specified directory, outputting context-rich results that include surrounding lines. This is particularly useful for understanding code patterns, finding specific implementations, or identifying areas that need refactoring. - You can use the list_code_definition_names tool to get an overview of source code definitions for all files at the top level of a specified directory. This can be particularly useful when you need to understand the broader context and relationships between certain parts of the code. You may need to call this tool multiple times to understand various parts of the codebase related to the task. - For example, when asked to make edits or improvements you might analyze the file structure in the initial environment_details to get an overview of the project, then use list_code_definition_names to get further insight using source code definitions for files located in relevant directories, then read_file to examine the contents of relevant files, analyze the code and suggest improvements or make necessary edits, then use the write_to_file ${diffStrategy ? "or apply_diff " : ""}tool to apply the changes. If you refactored code that could affect other parts of the codebase, you could use search_files to ensure you update other files as needed. - You can use the execute_command tool to run commands on the user's computer whenever you feel it can help accomplish the user's task. When you need to execute a CLI command, you must provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, since they are more flexible and easier to run. Interactive and long-running commands are allowed, since the commands are run in the user's VSCode terminal. The user may keep commands running in the background and you will be kept updated on their status along the way. Each command you execute is run in a new terminal instance.${ - supportsComputerUse - ? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser." - : "" -}${mcpHub ? ` + supportsComputerUse + ? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser." + : "" + }${ + mcpHub + ? ` - You have access to MCP servers that may provide additional tools and resources. Each server may provide different capabilities that you can use to accomplish tasks more effectively. -` : ''}` -} \ No newline at end of file +` + : "" + }` +} diff --git a/src/core/prompts/sections/custom-instructions.ts b/src/core/prompts/sections/custom-instructions.ts index 2d68787..b55e472 100644 --- a/src/core/prompts/sections/custom-instructions.ts +++ b/src/core/prompts/sections/custom-instructions.ts @@ -1,46 +1,51 @@ -import fs from 'fs/promises' -import path from 'path' +import fs from "fs/promises" +import path from "path" export async function loadRuleFiles(cwd: string): Promise { - const ruleFiles = ['.clinerules', '.cursorrules', '.windsurfrules'] - let combinedRules = '' + const ruleFiles = [".clinerules", ".cursorrules", ".windsurfrules"] + let combinedRules = "" - for (const file of ruleFiles) { - try { - const content = await fs.readFile(path.join(cwd, file), 'utf-8') - if (content.trim()) { - combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n` - } - } catch (err) { - // Silently skip if file doesn't exist - if ((err as NodeJS.ErrnoException).code !== 'ENOENT') { - throw err - } - } - } + for (const file of ruleFiles) { + try { + const content = await fs.readFile(path.join(cwd, file), "utf-8") + if (content.trim()) { + combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n` + } + } catch (err) { + // Silently skip if file doesn't exist + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + throw err + } + } + } - return combinedRules + return combinedRules } -export async function addCustomInstructions(customInstructions: string, cwd: string, preferredLanguage?: string): Promise { - const ruleFileContent = await loadRuleFiles(cwd) - const allInstructions = [] +export async function addCustomInstructions( + customInstructions: string, + cwd: string, + preferredLanguage?: string, +): Promise { + const ruleFileContent = await loadRuleFiles(cwd) + const allInstructions = [] - if (preferredLanguage) { - allInstructions.push(`You should always speak and think in the ${preferredLanguage} language.`) - } - - if (customInstructions.trim()) { - allInstructions.push(customInstructions.trim()) - } + if (preferredLanguage) { + allInstructions.push(`You should always speak and think in the ${preferredLanguage} language.`) + } - if (ruleFileContent && ruleFileContent.trim()) { - allInstructions.push(ruleFileContent.trim()) - } + if (customInstructions.trim()) { + allInstructions.push(customInstructions.trim()) + } - const joinedInstructions = allInstructions.join('\n\n') + if (ruleFileContent && ruleFileContent.trim()) { + allInstructions.push(ruleFileContent.trim()) + } - return joinedInstructions ? ` + const joinedInstructions = allInstructions.join("\n\n") + + return joinedInstructions + ? ` ==== USER'S CUSTOM INSTRUCTIONS @@ -48,5 +53,5 @@ USER'S CUSTOM INSTRUCTIONS The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the TOOL USE guidelines. ${joinedInstructions}` - : "" -} \ No newline at end of file + : "" +} diff --git a/src/core/prompts/sections/index.ts b/src/core/prompts/sections/index.ts index 38985d9..06cfcb6 100644 --- a/src/core/prompts/sections/index.ts +++ b/src/core/prompts/sections/index.ts @@ -1,8 +1,8 @@ -export { getRulesSection } from './rules' -export { getSystemInfoSection } from './system-info' -export { getObjectiveSection } from './objective' -export { addCustomInstructions } from './custom-instructions' -export { getSharedToolUseSection } from './tool-use' -export { getMcpServersSection } from './mcp-servers' -export { getToolUseGuidelinesSection } from './tool-use-guidelines' -export { getCapabilitiesSection } from './capabilities' \ No newline at end of file +export { getRulesSection } from "./rules" +export { getSystemInfoSection } from "./system-info" +export { getObjectiveSection } from "./objective" +export { addCustomInstructions } from "./custom-instructions" +export { getSharedToolUseSection } from "./tool-use" +export { getMcpServersSection } from "./mcp-servers" +export { getToolUseGuidelinesSection } from "./tool-use-guidelines" +export { getCapabilitiesSection } from "./capabilities" diff --git a/src/core/prompts/sections/mcp-servers.ts b/src/core/prompts/sections/mcp-servers.ts index e4d264b..774d7df 100644 --- a/src/core/prompts/sections/mcp-servers.ts +++ b/src/core/prompts/sections/mcp-servers.ts @@ -2,47 +2,48 @@ import { DiffStrategy } from "../../diff/DiffStrategy" import { McpHub } from "../../../services/mcp/McpHub" export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffStrategy): Promise { - if (!mcpHub) { - return ''; - } + if (!mcpHub) { + return "" + } - const connectedServers = mcpHub.getServers().length > 0 - ? `${mcpHub - .getServers() - .filter((server) => server.status === "connected") - .map((server) => { - const tools = server.tools - ?.map((tool) => { - const schemaStr = tool.inputSchema - ? ` Input Schema: + const connectedServers = + mcpHub.getServers().length > 0 + ? `${mcpHub + .getServers() + .filter((server) => server.status === "connected") + .map((server) => { + const tools = server.tools + ?.map((tool) => { + const schemaStr = tool.inputSchema + ? ` Input Schema: ${JSON.stringify(tool.inputSchema, null, 2).split("\n").join("\n ")}` - : "" + : "" - return `- ${tool.name}: ${tool.description}\n${schemaStr}` - }) - .join("\n\n") + return `- ${tool.name}: ${tool.description}\n${schemaStr}` + }) + .join("\n\n") - const templates = server.resourceTemplates - ?.map((template) => `- ${template.uriTemplate} (${template.name}): ${template.description}`) - .join("\n") + const templates = server.resourceTemplates + ?.map((template) => `- ${template.uriTemplate} (${template.name}): ${template.description}`) + .join("\n") - const resources = server.resources - ?.map((resource) => `- ${resource.uri} (${resource.name}): ${resource.description}`) - .join("\n") + const resources = server.resources + ?.map((resource) => `- ${resource.uri} (${resource.name}): ${resource.description}`) + .join("\n") - const config = JSON.parse(server.config) + const config = JSON.parse(server.config) - return ( - `## ${server.name} (\`${config.command}${config.args && Array.isArray(config.args) ? ` ${config.args.join(" ")}` : ""}\`)` + - (tools ? `\n\n### Available Tools\n${tools}` : "") + - (templates ? `\n\n### Resource Templates\n${templates}` : "") + - (resources ? `\n\n### Direct Resources\n${resources}` : "") - ) - }) - .join("\n\n")}` - : "(No MCP servers currently connected)"; + return ( + `## ${server.name} (\`${config.command}${config.args && Array.isArray(config.args) ? ` ${config.args.join(" ")}` : ""}\`)` + + (tools ? `\n\n### Available Tools\n${tools}` : "") + + (templates ? `\n\n### Resource Templates\n${templates}` : "") + + (resources ? `\n\n### Direct Resources\n${resources}` : "") + ) + }) + .join("\n\n")}` + : "(No MCP servers currently connected)" - return `MCP SERVERS + return `MCP SERVERS The Model Context Protocol (MCP) enables communication between the system and locally running MCP servers that provide additional tools and resources to extend your capabilities. @@ -397,11 +398,11 @@ IMPORTANT: Regardless of what else you see in the MCP settings file, you must de ## Editing MCP Servers The user may ask to add tools or resources that may make sense to add to an existing MCP server (listed under 'Connected MCP Servers' above: ${ - mcpHub - .getServers() - .map((server) => server.name) - .join(", ") || "(None running currently)" -}, e.g. if it would use the same API. This would be possible if you can locate the MCP server repository on the user's system by looking at the server arguments for a filepath. You might then use list_files and read_file to explore the files in the repository, and use write_to_file${diffStrategy ? " or apply_diff" : ""} to make changes to the files. + mcpHub + .getServers() + .map((server) => server.name) + .join(", ") || "(None running currently)" + }, e.g. if it would use the same API. This would be possible if you can locate the MCP server repository on the user's system by looking at the server arguments for a filepath. You might then use list_files and read_file to explore the files in the repository, and use write_to_file${diffStrategy ? " or apply_diff" : ""} to make changes to the files. However some MCP servers may be running from installed packages rather than a local repository, in which case it may make more sense to create a new MCP server. @@ -410,4 +411,4 @@ However some MCP servers may be running from installed packages rather than a lo The user may not always request the use or creation of MCP servers. Instead, they might provide tasks that can be completed with existing tools. While using the MCP SDK to extend your capabilities can be useful, it's important to understand that this is just one specialized type of task you can accomplish. You should only implement MCP servers when the user explicitly requests it (e.g., "add a tool that..."). Remember: The MCP documentation and example provided above are to help you understand and work with existing MCP servers or create new ones when requested by the user. You already have access to tools and capabilities that can be used to accomplish a wide range of tasks.` -} \ No newline at end of file +} diff --git a/src/core/prompts/sections/objective.ts b/src/core/prompts/sections/objective.ts index 441e2a3..66cefce 100644 --- a/src/core/prompts/sections/objective.ts +++ b/src/core/prompts/sections/objective.ts @@ -1,5 +1,5 @@ export function getObjectiveSection(): string { - return `==== + return `==== OBJECTIVE @@ -10,4 +10,4 @@ You accomplish a given task iteratively, breaking it down into clear steps and w 3. Remember, you have extensive capabilities with access to a wide range of tools that can be used in powerful and clever ways as necessary to accomplish each goal. Before calling a tool, do some analysis within tags. First, analyze the file structure provided in environment_details to gain context and insights for proceeding effectively. Then, think about which of the provided tools is the most relevant tool to accomplish the user's task. Next, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool use. BUT, if one of the values for a required parameter is missing, DO NOT invoke the tool (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters using the ask_followup_question tool. DO NOT ask for more information on optional parameters if it is not provided. 4. Once you've completed the user's task, you must use the attempt_completion tool to present the result of the task to the user. You may also provide a CLI command to showcase the result of your task; this can be particularly useful for web development tasks, where you can run e.g. \`open index.html\` to show the website you've built. 5. The user may provide feedback, which you can use to make improvements and try again. But DO NOT continue in pointless back and forth conversations, i.e. don't end your responses with questions or offers for further assistance.` -} \ No newline at end of file +} diff --git a/src/core/prompts/sections/rules.ts b/src/core/prompts/sections/rules.ts index 8a0ed33..df6da0c 100644 --- a/src/core/prompts/sections/rules.ts +++ b/src/core/prompts/sections/rules.ts @@ -1,11 +1,7 @@ import { DiffStrategy } from "../../diff/DiffStrategy" -export function getRulesSection( - cwd: string, - supportsComputerUse: boolean, - diffStrategy?: DiffStrategy -): string { - return `==== +export function getRulesSection(cwd: string, supportsComputerUse: boolean, diffStrategy?: DiffStrategy): string { + return `==== RULES @@ -23,10 +19,10 @@ ${diffStrategy ? "- You should use apply_diff instead of write_to_file when maki - When executing commands, if you don't see the expected output, assume the terminal executed the command successfully and proceed with the task. The user's terminal may be unable to stream the output back properly. If you absolutely need to see the actual terminal output, use the ask_followup_question tool to request the user to copy and paste it back to you. - The user may provide a file's contents directly in their message, in which case you shouldn't use the read_file tool to get the file contents again since you already have it. - Your goal is to try to accomplish the user's task, NOT engage in a back and forth conversation.${ - supportsComputerUse - ? '\n- The user may ask generic non-development tasks, such as "what\'s the latest news" or "look up the weather in San Diego", in which case you might use the browser_action tool to complete the task if it makes sense to do so, rather than trying to create a website or using curl to answer the question. However, if an available MCP server tool or resource can be used instead, you should prefer to use it over browser_action.' - : "" -} + supportsComputerUse + ? '\n- The user may ask generic non-development tasks, such as "what\'s the latest news" or "look up the weather in San Diego", in which case you might use the browser_action tool to complete the task if it makes sense to do so, rather than trying to create a website or using curl to answer the question. However, if an available MCP server tool or resource can be used instead, you should prefer to use it over browser_action.' + : "" + } - NEVER end attempt_completion result with a question or request to engage in further conversation! Formulate the end of your result in a way that is final and does not require further input from the user. - You are STRICTLY FORBIDDEN from starting your messages with "Great", "Certainly", "Okay", "Sure". You should NOT be conversational in your responses, but rather direct and to the point. For example you should NOT say "Great, I've updated the CSS" but instead something like "I've updated the CSS". It is important you be clear and technical in your messages. - When presented with images, utilize your vision capabilities to thoroughly examine them and extract meaningful information. Incorporate these insights into your thought process as you accomplish the user's task. @@ -35,8 +31,8 @@ ${diffStrategy ? "- You should use apply_diff instead of write_to_file when maki - When using the write_to_file tool, ALWAYS provide the COMPLETE file content in your response. This is NON-NEGOTIABLE. Partial updates or placeholders like '// rest of code unchanged' are STRICTLY FORBIDDEN. You MUST include ALL parts of the file, even if they haven't been modified. Failure to do so will result in incomplete or broken code, severely impacting the user's project. - MCP operations should be used one at a time, similar to other tool usage. Wait for confirmation of success before proceeding with additional operations. - It is critical you wait for the user's response after each tool use, in order to confirm the success of the tool use. For example, if asked to make a todo app, you would create a file, wait for the user's response it was created successfully, then create another file if needed, wait for the user's response it was created successfully, etc.${ - supportsComputerUse - ? " Then if you want to test your work, you might use browser_action to launch the site, wait for the user's response confirming the site was launched along with a screenshot, then perhaps e.g., click a button to test functionality if needed, wait for the user's response confirming the button was clicked along with a screenshot of the new state, before finally closing the browser." - : "" -}` -} \ No newline at end of file + supportsComputerUse + ? " Then if you want to test your work, you might use browser_action to launch the site, wait for the user's response confirming the site was launched along with a screenshot, then perhaps e.g., click a button to test functionality if needed, wait for the user's response confirming the button was clicked along with a screenshot of the new state, before finally closing the browser." + : "" + }` +} diff --git a/src/core/prompts/sections/system-info.ts b/src/core/prompts/sections/system-info.ts index 0b06565..5721b1b 100644 --- a/src/core/prompts/sections/system-info.ts +++ b/src/core/prompts/sections/system-info.ts @@ -3,7 +3,7 @@ import os from "os" import osName from "os-name" export function getSystemInfoSection(cwd: string): string { - return `==== + return `==== SYSTEM INFORMATION @@ -13,4 +13,4 @@ Home Directory: ${os.homedir().toPosix()} Current Working Directory: ${cwd.toPosix()} When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('/test/path') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop.` -} \ No newline at end of file +} diff --git a/src/core/prompts/sections/tool-use-guidelines.ts b/src/core/prompts/sections/tool-use-guidelines.ts index 5ce57cd..e1f1d57 100644 --- a/src/core/prompts/sections/tool-use-guidelines.ts +++ b/src/core/prompts/sections/tool-use-guidelines.ts @@ -1,5 +1,5 @@ export function getToolUseGuidelinesSection(): string { - return `# Tool Use Guidelines + return `# Tool Use Guidelines 1. In tags, assess what information you already have and what information you need to proceed with the task. 2. Choose the most appropriate tool based on the task and the tool descriptions provided. Assess if you need additional information to proceed, and which of the available tools would be most effective for gathering this information. For example using the list_files tool is more effective than running a command like \`ls\` in the terminal. It's critical that you think about each available tool and use the one that best fits the current step in the task. @@ -19,4 +19,4 @@ It is crucial to proceed step-by-step, waiting for the user's message after each 4. Ensure that each action builds correctly on the previous ones. By waiting for and carefully considering the user's response after each tool use, you can react accordingly and make informed decisions about how to proceed with the task. This iterative process helps ensure the overall success and accuracy of your work.` -} \ No newline at end of file +} diff --git a/src/core/prompts/sections/tool-use.ts b/src/core/prompts/sections/tool-use.ts index 02b13cc..e567187 100644 --- a/src/core/prompts/sections/tool-use.ts +++ b/src/core/prompts/sections/tool-use.ts @@ -1,5 +1,5 @@ export function getSharedToolUseSection(): string { - return `==== + return `==== TOOL USE @@ -22,4 +22,4 @@ For example: Always adhere to this format for the tool use to ensure proper parsing and execution.` -} \ No newline at end of file +} diff --git a/src/core/prompts/system.ts b/src/core/prompts/system.ts index 8dc2961..8a77f0d 100644 --- a/src/core/prompts/system.ts +++ b/src/core/prompts/system.ts @@ -3,87 +3,84 @@ import { DiffStrategy } from "../diff/DiffStrategy" import { McpHub } from "../../services/mcp/McpHub" import { getToolDescriptionsForMode } from "./tools" import { - getRulesSection, - getSystemInfoSection, - getObjectiveSection, - getSharedToolUseSection, - getMcpServersSection, - getToolUseGuidelinesSection, - getCapabilitiesSection + getRulesSection, + getSystemInfoSection, + getObjectiveSection, + getSharedToolUseSection, + getMcpServersSection, + getToolUseGuidelinesSection, + getCapabilitiesSection, } from "./sections" -import fs from 'fs/promises' -import path from 'path' +import fs from "fs/promises" +import path from "path" async function loadRuleFiles(cwd: string, mode: Mode): Promise { - let combinedRules = '' + let combinedRules = "" - // First try mode-specific rules - const modeSpecificFile = `.clinerules-${mode}` - try { - const content = await fs.readFile(path.join(cwd, modeSpecificFile), 'utf-8') - if (content.trim()) { - combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n` - } - } catch (err) { - // Silently skip if file doesn't exist - if ((err as NodeJS.ErrnoException).code !== 'ENOENT') { - throw err - } - } + // First try mode-specific rules + const modeSpecificFile = `.clinerules-${mode}` + try { + const content = await fs.readFile(path.join(cwd, modeSpecificFile), "utf-8") + if (content.trim()) { + combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n` + } + } catch (err) { + // Silently skip if file doesn't exist + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + throw err + } + } - // Then try generic rules files - const genericRuleFiles = ['.clinerules'] - for (const file of genericRuleFiles) { - try { - const content = await fs.readFile(path.join(cwd, file), 'utf-8') - if (content.trim()) { - combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n` - } - } catch (err) { - // Silently skip if file doesn't exist - if ((err as NodeJS.ErrnoException).code !== 'ENOENT') { - throw err - } - } - } + // Then try generic rules files + const genericRuleFiles = [".clinerules"] + for (const file of genericRuleFiles) { + try { + const content = await fs.readFile(path.join(cwd, file), "utf-8") + if (content.trim()) { + combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n` + } + } catch (err) { + // Silently skip if file doesn't exist + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + throw err + } + } + } - return combinedRules + return combinedRules } interface State { - customInstructions?: string; - customPrompts?: CustomPrompts; - preferredLanguage?: string; + customInstructions?: string + customPrompts?: CustomPrompts + preferredLanguage?: string } -export async function addCustomInstructions( - state: State, - cwd: string, - mode: Mode = defaultModeSlug -): Promise { - const ruleFileContent = await loadRuleFiles(cwd, mode) - const allInstructions = [] +export async function addCustomInstructions(state: State, cwd: string, mode: Mode = defaultModeSlug): Promise { + const ruleFileContent = await loadRuleFiles(cwd, mode) + const allInstructions = [] - if (state.preferredLanguage) { - allInstructions.push(`You should always speak and think in the ${state.preferredLanguage} language.`) - } + if (state.preferredLanguage) { + allInstructions.push(`You should always speak and think in the ${state.preferredLanguage} language.`) + } - if (state.customInstructions?.trim()) { - allInstructions.push(state.customInstructions.trim()) - } + if (state.customInstructions?.trim()) { + allInstructions.push(state.customInstructions.trim()) + } - const customPrompt = state.customPrompts?.[mode] - if (typeof customPrompt === 'object' && customPrompt?.customInstructions?.trim()) { - allInstructions.push(customPrompt.customInstructions.trim()) - } + const customPrompt = state.customPrompts?.[mode] + if (typeof customPrompt === "object" && customPrompt?.customInstructions?.trim()) { + allInstructions.push(customPrompt.customInstructions.trim()) + } - if (ruleFileContent && ruleFileContent.trim()) { - allInstructions.push(ruleFileContent.trim()) - } + if (ruleFileContent && ruleFileContent.trim()) { + allInstructions.push(ruleFileContent.trim()) + } - const joinedInstructions = allInstructions.join('\n\n') + const joinedInstructions = allInstructions.join("\n\n") - return joinedInstructions ? ` + return joinedInstructions + ? ` ==== USER'S CUSTOM INSTRUCTIONS @@ -91,19 +88,19 @@ USER'S CUSTOM INSTRUCTIONS The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the TOOL USE guidelines. ${joinedInstructions}` - : "" + : "" } async function generatePrompt( - cwd: string, - supportsComputerUse: boolean, - mode: Mode, - mcpHub?: McpHub, - diffStrategy?: DiffStrategy, - browserViewportSize?: string, - promptComponent?: PromptComponent, + cwd: string, + supportsComputerUse: boolean, + mode: Mode, + mcpHub?: McpHub, + diffStrategy?: DiffStrategy, + browserViewportSize?: string, + promptComponent?: PromptComponent, ): Promise { - const basePrompt = `${promptComponent?.roleDefinition || getRoleDefinition(mode)} + const basePrompt = `${promptComponent?.roleDefinition || getRoleDefinition(mode)} ${getSharedToolUseSection()} @@ -119,38 +116,38 @@ ${getRulesSection(cwd, supportsComputerUse, diffStrategy)} ${getSystemInfoSection(cwd)} -${getObjectiveSection()}`; +${getObjectiveSection()}` - return basePrompt; + return basePrompt } export const SYSTEM_PROMPT = async ( - cwd: string, - supportsComputerUse: boolean, - mcpHub?: McpHub, - diffStrategy?: DiffStrategy, - browserViewportSize?: string, - mode: Mode = defaultModeSlug, - customPrompts?: CustomPrompts, + cwd: string, + supportsComputerUse: boolean, + mcpHub?: McpHub, + diffStrategy?: DiffStrategy, + browserViewportSize?: string, + mode: Mode = defaultModeSlug, + customPrompts?: CustomPrompts, ) => { - const getPromptComponent = (value: unknown) => { - if (typeof value === 'object' && value !== null) { - return value as PromptComponent; - } - return undefined; - }; + const getPromptComponent = (value: unknown) => { + if (typeof value === "object" && value !== null) { + return value as PromptComponent + } + return undefined + } - // Use default mode if not found - const currentMode = modes.find(m => m.slug === mode) || modes[0]; - const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug]); + // Use default mode if not found + const currentMode = modes.find((m) => m.slug === mode) || modes[0] + const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug]) - return generatePrompt( - cwd, - supportsComputerUse, - currentMode.slug, - mcpHub, - diffStrategy, - browserViewportSize, - promptComponent - ); + return generatePrompt( + cwd, + supportsComputerUse, + currentMode.slug, + mcpHub, + diffStrategy, + browserViewportSize, + promptComponent, + ) } diff --git a/src/core/prompts/tools/access-mcp-resource.ts b/src/core/prompts/tools/access-mcp-resource.ts index cada876..693705b 100644 --- a/src/core/prompts/tools/access-mcp-resource.ts +++ b/src/core/prompts/tools/access-mcp-resource.ts @@ -1,10 +1,10 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getAccessMcpResourceDescription(args: ToolArgs): string | undefined { - if (!args.mcpHub) { - return undefined; - } - return `## access_mcp_resource + if (!args.mcpHub) { + return undefined + } + return `## access_mcp_resource Description: Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information. Parameters: - server_name: (required) The name of the MCP server providing the resource @@ -21,4 +21,4 @@ Example: Requesting to access an MCP resource weather-server weather://san-francisco/current ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/ask-followup-question.ts b/src/core/prompts/tools/ask-followup-question.ts index ac0cea0..fbd805e 100644 --- a/src/core/prompts/tools/ask-followup-question.ts +++ b/src/core/prompts/tools/ask-followup-question.ts @@ -1,5 +1,5 @@ export function getAskFollowupQuestionDescription(): string { - return `## ask_followup_question + return `## ask_followup_question Description: Ask the user a question to gather additional information needed to complete the task. This tool should be used when you encounter ambiguities, need clarification, or require more details to proceed effectively. It allows for interactive problem-solving by enabling direct communication with the user. Use this tool judiciously to maintain a balance between gathering necessary information and avoiding excessive back-and-forth. Parameters: - question: (required) The question to ask the user. This should be a clear, specific question that addresses the information you need. @@ -12,4 +12,4 @@ Example: Requesting to ask the user for the path to the frontend-config.json fil What is the path to the frontend-config.json file? ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/attempt-completion.ts b/src/core/prompts/tools/attempt-completion.ts index e66dd42..4418c8d 100644 --- a/src/core/prompts/tools/attempt-completion.ts +++ b/src/core/prompts/tools/attempt-completion.ts @@ -1,5 +1,5 @@ export function getAttemptCompletionDescription(): string { - return `## attempt_completion + return `## attempt_completion Description: After each tool use, the user will respond with the result of that tool use, i.e. if it succeeded or failed, along with any reasons for failure. Once you've received the results of tool uses and can confirm that the task is complete, use this tool to present the result of your work to the user. Optionally you may provide a CLI command to showcase the result of your work. The user may respond with feedback if they are not satisfied with the result, which you can use to make improvements and try again. IMPORTANT NOTE: This tool CANNOT be used until you've confirmed from the user that any previous tool uses were successful. Failure to do so will result in code corruption and system failure. Before using this tool, you must ask yourself in tags if you've confirmed from the user that any previous tool uses were successful. If not, then DO NOT use this tool. Parameters: @@ -20,4 +20,4 @@ I've updated the CSS open index.html ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/browser-action.ts b/src/core/prompts/tools/browser-action.ts index a02f886..9b5f1c4 100644 --- a/src/core/prompts/tools/browser-action.ts +++ b/src/core/prompts/tools/browser-action.ts @@ -1,10 +1,10 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getBrowserActionDescription(args: ToolArgs): string | undefined { - if (!args.supportsComputerUse) { - return undefined; - } - return `## browser_action + if (!args.supportsComputerUse) { + return undefined + } + return `## browser_action Description: Request to interact with a Puppeteer-controlled browser. Every action, except \`close\`, will be responded to with a screenshot of the browser's current state, along with any new console logs. You may only perform one browser action per message, and wait for the user's response including a screenshot and logs to determine the next action. - The sequence of actions **must always start with** launching the browser at a URL, and **must always end with** closing the browser. If you need to visit a new URL that is not possible to navigate to from the current webpage, you must first close the browser, then launch again at the new URL. - While the browser is active, only the \`browser_action\` tool can be used. No other tools should be called during this time. You may proceed to use other tools only after closing the browser. For example if you run into an error and need to fix a file, you must close the browser, then use other tools to make the necessary changes, then re-launch the browser to verify the result. @@ -49,4 +49,4 @@ Example: Requesting to click on the element at coordinates 450,300 click 450,300 ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/execute-command.ts b/src/core/prompts/tools/execute-command.ts index ea3f125..e773a2f 100644 --- a/src/core/prompts/tools/execute-command.ts +++ b/src/core/prompts/tools/execute-command.ts @@ -1,7 +1,7 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getExecuteCommandDescription(args: ToolArgs): string | undefined { - return `## execute_command + return `## execute_command Description: Request to execute a CLI command on the system. Use this when you need to perform system operations or run specific commands to accomplish any step in the user's task. You must tailor your command to the user's system and provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, as they are more flexible and easier to run. Commands will be executed in the current working directory: ${args.cwd} Parameters: - command: (required) The CLI command to execute. This should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions. @@ -14,4 +14,4 @@ Example: Requesting to execute npm run dev npm run dev ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/index.ts b/src/core/prompts/tools/index.ts index fe7206c..9627a32 100644 --- a/src/core/prompts/tools/index.ts +++ b/src/core/prompts/tools/index.ts @@ -1,79 +1,80 @@ -import { getExecuteCommandDescription } from './execute-command' -import { getReadFileDescription } from './read-file' -import { getWriteToFileDescription } from './write-to-file' -import { getSearchFilesDescription } from './search-files' -import { getListFilesDescription } from './list-files' -import { getListCodeDefinitionNamesDescription } from './list-code-definition-names' -import { getBrowserActionDescription } from './browser-action' -import { getAskFollowupQuestionDescription } from './ask-followup-question' -import { getAttemptCompletionDescription } from './attempt-completion' -import { getUseMcpToolDescription } from './use-mcp-tool' -import { getAccessMcpResourceDescription } from './access-mcp-resource' -import { DiffStrategy } from '../../diff/DiffStrategy' -import { McpHub } from '../../../services/mcp/McpHub' -import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from '../../../shared/modes' -import { ToolArgs } from './types' +import { getExecuteCommandDescription } from "./execute-command" +import { getReadFileDescription } from "./read-file" +import { getWriteToFileDescription } from "./write-to-file" +import { getSearchFilesDescription } from "./search-files" +import { getListFilesDescription } from "./list-files" +import { getListCodeDefinitionNamesDescription } from "./list-code-definition-names" +import { getBrowserActionDescription } from "./browser-action" +import { getAskFollowupQuestionDescription } from "./ask-followup-question" +import { getAttemptCompletionDescription } from "./attempt-completion" +import { getUseMcpToolDescription } from "./use-mcp-tool" +import { getAccessMcpResourceDescription } from "./access-mcp-resource" +import { DiffStrategy } from "../../diff/DiffStrategy" +import { McpHub } from "../../../services/mcp/McpHub" +import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from "../../../shared/modes" +import { ToolArgs } from "./types" // Map of tool names to their description functions const toolDescriptionMap: Record string | undefined> = { - 'execute_command': args => getExecuteCommandDescription(args), - 'read_file': args => getReadFileDescription(args), - 'write_to_file': args => getWriteToFileDescription(args), - 'search_files': args => getSearchFilesDescription(args), - 'list_files': args => getListFilesDescription(args), - 'list_code_definition_names': args => getListCodeDefinitionNamesDescription(args), - 'browser_action': args => getBrowserActionDescription(args), - 'ask_followup_question': () => getAskFollowupQuestionDescription(), - 'attempt_completion': () => getAttemptCompletionDescription(), - 'use_mcp_tool': args => getUseMcpToolDescription(args), - 'access_mcp_resource': args => getAccessMcpResourceDescription(args), - 'apply_diff': args => args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : '' -}; + execute_command: (args) => getExecuteCommandDescription(args), + read_file: (args) => getReadFileDescription(args), + write_to_file: (args) => getWriteToFileDescription(args), + search_files: (args) => getSearchFilesDescription(args), + list_files: (args) => getListFilesDescription(args), + list_code_definition_names: (args) => getListCodeDefinitionNamesDescription(args), + browser_action: (args) => getBrowserActionDescription(args), + ask_followup_question: () => getAskFollowupQuestionDescription(), + attempt_completion: () => getAttemptCompletionDescription(), + use_mcp_tool: (args) => getUseMcpToolDescription(args), + access_mcp_resource: (args) => getAccessMcpResourceDescription(args), + apply_diff: (args) => + args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : "", +} export function getToolDescriptionsForMode( - mode: Mode, - cwd: string, - supportsComputerUse: boolean, - diffStrategy?: DiffStrategy, - browserViewportSize?: string, - mcpHub?: McpHub + mode: Mode, + cwd: string, + supportsComputerUse: boolean, + diffStrategy?: DiffStrategy, + browserViewportSize?: string, + mcpHub?: McpHub, ): string { - const config = getModeConfig(mode); - const args: ToolArgs = { - cwd, - supportsComputerUse, - diffStrategy, - browserViewportSize, - mcpHub - }; + const config = getModeConfig(mode) + const args: ToolArgs = { + cwd, + supportsComputerUse, + diffStrategy, + browserViewportSize, + mcpHub, + } - // Map tool descriptions in the exact order specified in the mode's tools array - const descriptions = config.tools.map(([toolName, toolOptions]) => { - const descriptionFn = toolDescriptionMap[toolName]; - if (!descriptionFn || !isToolAllowedForMode(toolName as ToolName, mode)) { - return undefined; - } + // Map tool descriptions in the exact order specified in the mode's tools array + const descriptions = config.tools.map(([toolName, toolOptions]) => { + const descriptionFn = toolDescriptionMap[toolName] + if (!descriptionFn || !isToolAllowedForMode(toolName as ToolName, mode)) { + return undefined + } - return descriptionFn({ - ...args, - toolOptions - }); - }); + return descriptionFn({ + ...args, + toolOptions, + }) + }) - return `# Tools\n\n${descriptions.filter(Boolean).join('\n\n')}`; + return `# Tools\n\n${descriptions.filter(Boolean).join("\n\n")}` } // Export individual description functions for backward compatibility export { - getExecuteCommandDescription, - getReadFileDescription, - getWriteToFileDescription, - getSearchFilesDescription, - getListFilesDescription, - getListCodeDefinitionNamesDescription, - getBrowserActionDescription, - getAskFollowupQuestionDescription, - getAttemptCompletionDescription, - getUseMcpToolDescription, - getAccessMcpResourceDescription -} \ No newline at end of file + getExecuteCommandDescription, + getReadFileDescription, + getWriteToFileDescription, + getSearchFilesDescription, + getListFilesDescription, + getListCodeDefinitionNamesDescription, + getBrowserActionDescription, + getAskFollowupQuestionDescription, + getAttemptCompletionDescription, + getUseMcpToolDescription, + getAccessMcpResourceDescription, +} diff --git a/src/core/prompts/tools/list-code-definition-names.ts b/src/core/prompts/tools/list-code-definition-names.ts index c1849da..753ac4c 100644 --- a/src/core/prompts/tools/list-code-definition-names.ts +++ b/src/core/prompts/tools/list-code-definition-names.ts @@ -1,7 +1,7 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getListCodeDefinitionNamesDescription(args: ToolArgs): string { - return `## list_code_definition_names + return `## list_code_definition_names Description: Request to list definition names (classes, functions, methods, etc.) used in source code files at the top level of the specified directory. This tool provides insights into the codebase structure and important constructs, encapsulating high-level concepts and relationships that are crucial for understanding the overall architecture. Parameters: - path: (required) The path of the directory (relative to the current working directory ${args.cwd}) to list top level source code definitions for. @@ -14,4 +14,4 @@ Example: Requesting to list all top level source code definitions in the current . ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/list-files.ts b/src/core/prompts/tools/list-files.ts index e7913ad..1ec2b8e 100644 --- a/src/core/prompts/tools/list-files.ts +++ b/src/core/prompts/tools/list-files.ts @@ -1,7 +1,7 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getListFilesDescription(args: ToolArgs): string { - return `## list_files + return `## list_files Description: Request to list files and directories within the specified directory. If recursive is true, it will list all files and directories recursively. If recursive is false or not provided, it will only list the top-level contents. Do not use this tool to confirm the existence of files you may have created, as the user will let you know if the files were created successfully or not. Parameters: - path: (required) The path of the directory to list contents for (relative to the current working directory ${args.cwd}) @@ -17,4 +17,4 @@ Example: Requesting to list all files in the current directory . false ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/read-file.ts b/src/core/prompts/tools/read-file.ts index 8bc465b..ee52214 100644 --- a/src/core/prompts/tools/read-file.ts +++ b/src/core/prompts/tools/read-file.ts @@ -1,7 +1,7 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getReadFileDescription(args: ToolArgs): string { - return `## read_file + return `## read_file Description: Request to read the contents of a file at the specified path. Use this when you need to examine the contents of an existing file you do not know the contents of, for example to analyze code, review text files, or extract information from configuration files. The output includes line numbers prefixed to each line (e.g. "1 | const x = 1"), making it easier to reference specific lines when creating diffs or discussing code. Automatically extracts raw text from PDF and DOCX files. May not be suitable for other types of binary files, as it returns the raw content as a string. Parameters: - path: (required) The path of the file to read (relative to the current working directory ${args.cwd}) @@ -14,4 +14,4 @@ Example: Requesting to read frontend-config.json frontend-config.json ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/search-files.ts b/src/core/prompts/tools/search-files.ts index 272b858..8353cc4 100644 --- a/src/core/prompts/tools/search-files.ts +++ b/src/core/prompts/tools/search-files.ts @@ -1,7 +1,7 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getSearchFilesDescription(args: ToolArgs): string { - return `## search_files + return `## search_files Description: Request to perform a regex search across files in a specified directory, providing context-rich results. This tool searches for patterns or specific content across multiple files, displaying each match with encapsulating context. Parameters: - path: (required) The path of the directory to search in (relative to the current working directory ${args.cwd}). This directory will be recursively searched. diff --git a/src/core/prompts/tools/types.ts b/src/core/prompts/tools/types.ts index 57bbefb..2c2a60d 100644 --- a/src/core/prompts/tools/types.ts +++ b/src/core/prompts/tools/types.ts @@ -1,11 +1,11 @@ -import { DiffStrategy } from '../../diff/DiffStrategy' -import { McpHub } from '../../../services/mcp/McpHub' +import { DiffStrategy } from "../../diff/DiffStrategy" +import { McpHub } from "../../../services/mcp/McpHub" export type ToolArgs = { - cwd: string; - supportsComputerUse: boolean; - diffStrategy?: DiffStrategy; - browserViewportSize?: string; - mcpHub?: McpHub; - toolOptions?: any; -}; \ No newline at end of file + cwd: string + supportsComputerUse: boolean + diffStrategy?: DiffStrategy + browserViewportSize?: string + mcpHub?: McpHub + toolOptions?: any +} diff --git a/src/core/prompts/tools/use-mcp-tool.ts b/src/core/prompts/tools/use-mcp-tool.ts index 00a228b..ac9ef5b 100644 --- a/src/core/prompts/tools/use-mcp-tool.ts +++ b/src/core/prompts/tools/use-mcp-tool.ts @@ -1,10 +1,10 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getUseMcpToolDescription(args: ToolArgs): string | undefined { - if (!args.mcpHub) { - return undefined; - } - return `## use_mcp_tool + if (!args.mcpHub) { + return undefined + } + return `## use_mcp_tool Description: Request to use a tool provided by a connected MCP server. Each MCP server can provide multiple tools with different capabilities. Tools have defined input schemas that specify required and optional parameters. Parameters: - server_name: (required) The name of the MCP server providing the tool @@ -34,4 +34,4 @@ Example: Requesting to use an MCP tool } ` -} \ No newline at end of file +} diff --git a/src/core/prompts/tools/write-to-file.ts b/src/core/prompts/tools/write-to-file.ts index 88f5324..c2a311c 100644 --- a/src/core/prompts/tools/write-to-file.ts +++ b/src/core/prompts/tools/write-to-file.ts @@ -1,7 +1,7 @@ -import { ToolArgs } from './types'; +import { ToolArgs } from "./types" export function getWriteToFileDescription(args: ToolArgs): string { - return `## write_to_file + return `## write_to_file Description: Request to write full content to a file at the specified path. If the file exists, it will be overwritten with the provided content. If the file doesn't exist, it will be created. This tool will automatically create any directories needed to write the file. Parameters: - path: (required) The path of the file to write to (relative to the current working directory ${args.cwd}) diff --git a/src/core/prompts/types.ts b/src/core/prompts/types.ts index 043fcee..1ac1a18 100644 --- a/src/core/prompts/types.ts +++ b/src/core/prompts/types.ts @@ -1,52 +1,52 @@ -import { Mode } from '../../shared/modes'; +import { Mode } from "../../shared/modes" -export type { Mode }; +export type { Mode } export type ToolName = - | 'execute_command' - | 'read_file' - | 'write_to_file' - | 'apply_diff' - | 'search_files' - | 'list_files' - | 'list_code_definition_names' - | 'browser_action' - | 'use_mcp_tool' - | 'access_mcp_resource' - | 'ask_followup_question' - | 'attempt_completion'; + | "execute_command" + | "read_file" + | "write_to_file" + | "apply_diff" + | "search_files" + | "list_files" + | "list_code_definition_names" + | "browser_action" + | "use_mcp_tool" + | "access_mcp_resource" + | "ask_followup_question" + | "attempt_completion" export const CODE_TOOLS: ToolName[] = [ - 'execute_command', - 'read_file', - 'write_to_file', - 'apply_diff', - 'search_files', - 'list_files', - 'list_code_definition_names', - 'browser_action', - 'use_mcp_tool', - 'access_mcp_resource', - 'ask_followup_question', - 'attempt_completion' -]; + "execute_command", + "read_file", + "write_to_file", + "apply_diff", + "search_files", + "list_files", + "list_code_definition_names", + "browser_action", + "use_mcp_tool", + "access_mcp_resource", + "ask_followup_question", + "attempt_completion", +] export const ARCHITECT_TOOLS: ToolName[] = [ - 'read_file', - 'search_files', - 'list_files', - 'list_code_definition_names', - 'ask_followup_question', - 'attempt_completion' -]; + "read_file", + "search_files", + "list_files", + "list_code_definition_names", + "ask_followup_question", + "attempt_completion", +] export const ASK_TOOLS: ToolName[] = [ - 'read_file', - 'search_files', - 'list_files', - 'browser_action', - 'use_mcp_tool', - 'access_mcp_resource', - 'ask_followup_question', - 'attempt_completion' -]; \ No newline at end of file + "read_file", + "search_files", + "list_files", + "browser_action", + "use_mcp_tool", + "access_mcp_resource", + "ask_followup_question", + "attempt_completion", +] diff --git a/src/core/tool-lists.ts b/src/core/tool-lists.ts index 5878dcf..862106b 100644 --- a/src/core/tool-lists.ts +++ b/src/core/tool-lists.ts @@ -1,32 +1,32 @@ // Shared tools for architect and ask modes - read-only operations plus MCP and browser tools export const READONLY_ALLOWED_TOOLS = [ - 'read_file', - 'search_files', - 'list_files', - 'list_code_definition_names', - 'browser_action', - 'use_mcp_tool', - 'access_mcp_resource', - 'ask_followup_question', - 'attempt_completion' -] as const; + "read_file", + "search_files", + "list_files", + "list_code_definition_names", + "browser_action", + "use_mcp_tool", + "access_mcp_resource", + "ask_followup_question", + "attempt_completion", +] as const // Code mode has access to all tools export const CODE_ALLOWED_TOOLS = [ - 'execute_command', - 'read_file', - 'write_to_file', - 'apply_diff', - 'search_files', - 'list_files', - 'list_code_definition_names', - 'browser_action', - 'use_mcp_tool', - 'access_mcp_resource', - 'ask_followup_question', - 'attempt_completion' -] as const; + "execute_command", + "read_file", + "write_to_file", + "apply_diff", + "search_files", + "list_files", + "list_code_definition_names", + "browser_action", + "use_mcp_tool", + "access_mcp_resource", + "ask_followup_question", + "attempt_completion", +] as const // Tool name types for type safety -export type ReadOnlyToolName = typeof READONLY_ALLOWED_TOOLS[number]; -export type ToolName = typeof CODE_ALLOWED_TOOLS[number]; \ No newline at end of file +export type ReadOnlyToolName = (typeof READONLY_ALLOWED_TOOLS)[number] +export type ToolName = (typeof CODE_ALLOWED_TOOLS)[number] diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 89369da..0a775e2 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -98,7 +98,7 @@ type GlobalStateKey = | "modeApiConfigs" | "customPrompts" | "enhancementApiConfigId" - | "experimentalDiffStrategy" + | "experimentalDiffStrategy" | "autoApprovalEnabled" export const GlobalFileNames = { @@ -254,14 +254,12 @@ export class ClineProvider implements vscode.WebviewViewProvider { fuzzyMatchThreshold, mode, customInstructions: globalInstructions, - experimentalDiffStrategy + experimentalDiffStrategy, } = await this.getState() const modePrompt = customPrompts?.[mode] - const modeInstructions = typeof modePrompt === 'object' ? modePrompt.customInstructions : undefined - const effectiveInstructions = [globalInstructions, modeInstructions] - .filter(Boolean) - .join('\n\n') + const modeInstructions = typeof modePrompt === "object" ? modePrompt.customInstructions : undefined + const effectiveInstructions = [globalInstructions, modeInstructions].filter(Boolean).join("\n\n") this.cline = new Cline( this, @@ -272,7 +270,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { task, images, undefined, - experimentalDiffStrategy + experimentalDiffStrategy, ) } @@ -285,14 +283,12 @@ export class ClineProvider implements vscode.WebviewViewProvider { fuzzyMatchThreshold, mode, customInstructions: globalInstructions, - experimentalDiffStrategy + experimentalDiffStrategy, } = await this.getState() const modePrompt = customPrompts?.[mode] - const modeInstructions = typeof modePrompt === 'object' ? modePrompt.customInstructions : undefined - const effectiveInstructions = [globalInstructions, modeInstructions] - .filter(Boolean) - .join('\n\n') + const modeInstructions = typeof modePrompt === "object" ? modePrompt.customInstructions : undefined + const effectiveInstructions = [globalInstructions, modeInstructions].filter(Boolean).join("\n\n") this.cline = new Cline( this, @@ -303,7 +299,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { undefined, undefined, historyItem, - experimentalDiffStrategy + experimentalDiffStrategy, ) } @@ -403,7 +399,6 @@ export class ClineProvider implements vscode.WebviewViewProvider { async (message: WebviewMessage) => { switch (message.type) { case "webviewDidLaunch": - this.postStateToWebview() this.workspaceTracker?.initializeFilePaths() // don't await getTheme().then((theme) => @@ -450,53 +445,53 @@ export class ClineProvider implements vscode.WebviewViewProvider { } }) - - this.configManager.ListConfig().then(async (listApiConfig) => { - - if (!listApiConfig) { - return - } - - if (listApiConfig.length === 1) { - // check if first time init then sync with exist config - if (!checkExistKey(listApiConfig[0])) { - const { - apiConfiguration, - } = await this.getState() - await this.configManager.SaveConfig(listApiConfig[0].name ?? "default", apiConfiguration) - listApiConfig[0].apiProvider = apiConfiguration.apiProvider + this.configManager + .ListConfig() + .then(async (listApiConfig) => { + if (!listApiConfig) { + return } - } - let currentConfigName = await this.getGlobalState("currentApiConfigName") as string - - if (currentConfigName) { - if (!await this.configManager.HasConfig(currentConfigName)) { - // current config name not valid, get first config in list - await this.updateGlobalState("currentApiConfigName", listApiConfig?.[0]?.name) - if (listApiConfig?.[0]?.name) { - const apiConfig = await this.configManager.LoadConfig(listApiConfig?.[0]?.name); - - await Promise.all([ - this.updateGlobalState("listApiConfigMeta", listApiConfig), - this.postMessageToWebview({ type: "listApiConfig", listApiConfig }), - this.updateApiConfiguration(apiConfig), - ]) - await this.postStateToWebview() - return + if (listApiConfig.length === 1) { + // check if first time init then sync with exist config + if (!checkExistKey(listApiConfig[0])) { + const { apiConfiguration } = await this.getState() + await this.configManager.SaveConfig( + listApiConfig[0].name ?? "default", + apiConfiguration, + ) + listApiConfig[0].apiProvider = apiConfiguration.apiProvider } - } - } + let currentConfigName = (await this.getGlobalState("currentApiConfigName")) as string - await Promise.all( - [ + if (currentConfigName) { + if (!(await this.configManager.HasConfig(currentConfigName))) { + // current config name not valid, get first config in list + await this.updateGlobalState("currentApiConfigName", listApiConfig?.[0]?.name) + if (listApiConfig?.[0]?.name) { + const apiConfig = await this.configManager.LoadConfig( + listApiConfig?.[0]?.name, + ) + + await Promise.all([ + this.updateGlobalState("listApiConfigMeta", listApiConfig), + this.postMessageToWebview({ type: "listApiConfig", listApiConfig }), + this.updateApiConfiguration(apiConfig), + ]) + await this.postStateToWebview() + return + } + } + } + + await Promise.all([ await this.updateGlobalState("listApiConfigMeta", listApiConfig), - await this.postMessageToWebview({ type: "listApiConfig", listApiConfig }) - ] - ) - }).catch(console.error); + await this.postMessageToWebview({ type: "listApiConfig", listApiConfig }), + ]) + }) + .catch(console.error) break case "newTask": @@ -593,7 +588,10 @@ export class ClineProvider implements vscode.WebviewViewProvider { break case "refreshOpenAiModels": if (message?.values?.baseUrl && message?.values?.apiKey) { - const openAiModels = await this.getOpenAiModels(message?.values?.baseUrl, message?.values?.apiKey) + const openAiModels = await this.getOpenAiModels( + message?.values?.baseUrl, + message?.values?.apiKey, + ) this.postMessageToWebview({ type: "openAiModels", openAiModels }) } break @@ -625,12 +623,12 @@ export class ClineProvider implements vscode.WebviewViewProvider { break case "allowedCommands": - await this.context.globalState.update('allowedCommands', message.commands); + await this.context.globalState.update("allowedCommands", message.commands) // Also update workspace settings await vscode.workspace - .getConfiguration('roo-cline') - .update('allowedCommands', message.commands, vscode.ConfigurationTarget.Global); - break; + .getConfiguration("roo-cline") + .update("allowedCommands", message.commands, vscode.ConfigurationTarget.Global) + break case "openMcpSettings": { const mcpSettingsFilePath = await this.mcpHub?.getMcpSettingsFilePath() if (mcpSettingsFilePath) { @@ -651,7 +649,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.mcpHub?.toggleToolAlwaysAllow( message.serverName!, message.toolName!, - message.alwaysAllow! + message.alwaysAllow!, ) } catch (error) { console.error(`Failed to toggle auto-approve for tool ${message.toolName}:`, error) @@ -660,10 +658,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { } case "toggleMcpServer": { try { - await this.mcpHub?.toggleServerDisabled( - message.serverName!, - message.disabled! - ) + await this.mcpHub?.toggleServerDisabled(message.serverName!, message.disabled!) } catch (error) { console.error(`Failed to toggle MCP server ${message.serverName}:`, error) } @@ -683,7 +678,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { case "soundEnabled": const soundEnabled = message.bool ?? true await this.updateGlobalState("soundEnabled", soundEnabled) - setSoundEnabled(soundEnabled) // Add this line to update the sound utility + setSoundEnabled(soundEnabled) // Add this line to update the sound utility await this.postStateToWebview() break case "soundVolume": @@ -729,84 +724,84 @@ export class ClineProvider implements vscode.WebviewViewProvider { case "mode": const newMode = message.text as Mode await this.updateGlobalState("mode", newMode) - + // Load the saved API config for the new mode if it exists const savedConfigId = await this.configManager.GetModeConfigId(newMode) const listApiConfig = await this.configManager.ListConfig() - + // Update listApiConfigMeta first to ensure UI has latest data await this.updateGlobalState("listApiConfigMeta", listApiConfig) - + // If this mode has a saved config, use it if (savedConfigId) { - const config = listApiConfig?.find(c => c.id === savedConfigId) + const config = listApiConfig?.find((c) => c.id === savedConfigId) if (config?.name) { const apiConfig = await this.configManager.LoadConfig(config.name) await Promise.all([ this.updateGlobalState("currentApiConfigName", config.name), - this.updateApiConfiguration(apiConfig) + this.updateApiConfiguration(apiConfig), ]) } } else { // If no saved config for this mode, save current config as default const currentApiConfigName = await this.getGlobalState("currentApiConfigName") if (currentApiConfigName) { - const config = listApiConfig?.find(c => c.name === currentApiConfigName) + const config = listApiConfig?.find((c) => c.name === currentApiConfigName) if (config?.id) { await this.configManager.SetModeConfig(newMode, config.id) } } } - + await this.postStateToWebview() break case "updateEnhancedPrompt": - const existingPrompts = await this.getGlobalState("customPrompts") || {} - + const existingPrompts = (await this.getGlobalState("customPrompts")) || {} + const updatedPrompts = { ...existingPrompts, - enhance: message.text + enhance: message.text, } - + await this.updateGlobalState("customPrompts", updatedPrompts) - + // Get current state and explicitly include customPrompts const currentState = await this.getState() - + const stateWithPrompts = { ...currentState, - customPrompts: updatedPrompts + customPrompts: updatedPrompts, } - + // Post state with prompts this.view?.webview.postMessage({ type: "state", - state: stateWithPrompts + state: stateWithPrompts, }) break case "updatePrompt": if (message.promptMode && message.customPrompt !== undefined) { - const existingPrompts = await this.getGlobalState("customPrompts") || {} - + const existingPrompts = (await this.getGlobalState("customPrompts")) || {} + const updatedPrompts = { ...existingPrompts, - [message.promptMode]: message.customPrompt + [message.promptMode]: message.customPrompt, } - + await this.updateGlobalState("customPrompts", updatedPrompts) - + // Get current state and explicitly include customPrompts const currentState = await this.getState() - + const stateWithPrompts = { ...currentState, - customPrompts: updatedPrompts + customPrompts: updatedPrompts, } - + // Post state with prompts this.view?.webview.postMessage({ type: "state", - state: stateWithPrompts + state: stateWithPrompts, }) } break @@ -817,60 +812,79 @@ export class ClineProvider implements vscode.WebviewViewProvider { "Just this message", "This and all subsequent messages", ) - if ((answer === "Just this message" || answer === "This and all subsequent messages") && - this.cline && typeof message.value === 'number' && message.value) { - const timeCutoff = message.value - 1000; // 1 second buffer before the message to delete - const messageIndex = this.cline.clineMessages.findIndex(msg => msg.ts && msg.ts >= timeCutoff) - const apiConversationHistoryIndex = this.cline.apiConversationHistory.findIndex(msg => msg.ts && msg.ts >= timeCutoff) - + if ( + (answer === "Just this message" || answer === "This and all subsequent messages") && + this.cline && + typeof message.value === "number" && + message.value + ) { + const timeCutoff = message.value - 1000 // 1 second buffer before the message to delete + const messageIndex = this.cline.clineMessages.findIndex( + (msg) => msg.ts && msg.ts >= timeCutoff, + ) + const apiConversationHistoryIndex = this.cline.apiConversationHistory.findIndex( + (msg) => msg.ts && msg.ts >= timeCutoff, + ) + if (messageIndex !== -1) { const { historyItem } = await this.getTaskWithId(this.cline.taskId) - + if (answer === "Just this message") { // Find the next user message first const nextUserMessage = this.cline.clineMessages .slice(messageIndex + 1) - .find(msg => msg.type === "say" && msg.say === "user_feedback") - + .find((msg) => msg.type === "say" && msg.say === "user_feedback") + // Handle UI messages if (nextUserMessage) { // Find absolute index of next user message - const nextUserMessageIndex = this.cline.clineMessages.findIndex(msg => msg === nextUserMessage) + const nextUserMessageIndex = this.cline.clineMessages.findIndex( + (msg) => msg === nextUserMessage, + ) // Keep messages before current message and after next user message await this.cline.overwriteClineMessages([ ...this.cline.clineMessages.slice(0, messageIndex), - ...this.cline.clineMessages.slice(nextUserMessageIndex) + ...this.cline.clineMessages.slice(nextUserMessageIndex), ]) } else { // If no next user message, keep only messages before current message await this.cline.overwriteClineMessages( - this.cline.clineMessages.slice(0, messageIndex) + this.cline.clineMessages.slice(0, messageIndex), ) } - + // Handle API messages if (apiConversationHistoryIndex !== -1) { if (nextUserMessage && nextUserMessage.ts) { // Keep messages before current API message and after next user message await this.cline.overwriteApiConversationHistory([ - ...this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex), - ...this.cline.apiConversationHistory.filter(msg => msg.ts && msg.ts >= nextUserMessage.ts) + ...this.cline.apiConversationHistory.slice( + 0, + apiConversationHistoryIndex, + ), + ...this.cline.apiConversationHistory.filter( + (msg) => msg.ts && msg.ts >= nextUserMessage.ts, + ), ]) } else { // If no next user message, keep only messages before current API message await this.cline.overwriteApiConversationHistory( - this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex) + this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex), ) } } } else if (answer === "This and all subsequent messages") { // Delete this message and all that follow - await this.cline.overwriteClineMessages(this.cline.clineMessages.slice(0, messageIndex)) + await this.cline.overwriteClineMessages( + this.cline.clineMessages.slice(0, messageIndex), + ) if (apiConversationHistoryIndex !== -1) { - await this.cline.overwriteApiConversationHistory(this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex)) + await this.cline.overwriteApiConversationHistory( + this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex), + ) } } - + await this.initClineWithHistoryItem(historyItem) } } @@ -891,12 +905,13 @@ export class ClineProvider implements vscode.WebviewViewProvider { case "enhancePrompt": if (message.text) { try { - const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } = await this.getState() - + const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } = + await this.getState() + // Try to get enhancement config first, fall back to current config let configToUse: ApiConfiguration = apiConfiguration if (enhancementApiConfigId) { - const config = listApiConfigMeta?.find(c => c.id === enhancementApiConfigId) + const config = listApiConfigMeta?.find((c) => c.id === enhancementApiConfigId) if (config?.name) { const loadedConfig = await this.configManager.LoadConfig(config.name) if (loadedConfig.apiProvider) { @@ -904,41 +919,49 @@ export class ClineProvider implements vscode.WebviewViewProvider { } } } - + const getEnhancePrompt = (value: string | PromptComponent | undefined): string => { - if (typeof value === 'string') { - return value; + if (typeof value === "string") { + return value } - return enhance.prompt; // Use the constant from modes.ts which we know is a string + return enhance.prompt // Use the constant from modes.ts which we know is a string } const enhancedPrompt = await enhancePrompt( configToUse, message.text, - getEnhancePrompt(customPrompts?.enhance) + getEnhancePrompt(customPrompts?.enhance), ) await this.postMessageToWebview({ type: "enhancedPrompt", - text: enhancedPrompt + text: enhancedPrompt, }) } catch (error) { console.error("Error enhancing prompt:", error) vscode.window.showErrorMessage("Failed to enhance prompt") await this.postMessageToWebview({ - type: "enhancedPrompt" + type: "enhancedPrompt", }) } } break case "getSystemPrompt": try { - const { apiConfiguration, customPrompts, customInstructions, preferredLanguage, browserViewportSize, mcpEnabled } = await this.getState() - const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || '' + const { + apiConfiguration, + customPrompts, + customInstructions, + preferredLanguage, + browserViewportSize, + mcpEnabled, + } = await this.getState() + const cwd = + vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || "" const mode = message.mode ?? defaultModeSlug const instructions = await addCustomInstructions( { customInstructions, customPrompts, preferredLanguage }, cwd, - mode + mode, ) const systemPrompt = await SYSTEM_PROMPT( @@ -948,14 +971,14 @@ export class ClineProvider implements vscode.WebviewViewProvider { undefined, browserViewportSize ?? "900x600", mode, - customPrompts + customPrompts, ) const fullPrompt = instructions ? `${systemPrompt}${instructions}` : systemPrompt - + await this.postMessageToWebview({ type: "systemPrompt", text: fullPrompt, - mode: message.mode + mode: message.mode, }) } catch (error) { console.error("Error getting system prompt:", error) @@ -969,7 +992,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { const commits = await searchCommits(message.query || "", cwd) await this.postMessageToWebview({ type: "commitSearchResults", - commits + commits, }) } catch (error) { console.error("Error searching commits:", error) @@ -981,9 +1004,9 @@ export class ClineProvider implements vscode.WebviewViewProvider { case "upsertApiConfiguration": if (message.text && message.apiConfiguration) { try { - await this.configManager.SaveConfig(message.text, message.apiConfiguration); - let listApiConfig = await this.configManager.ListConfig(); - + await this.configManager.SaveConfig(message.text, message.apiConfiguration) + let listApiConfig = await this.configManager.ListConfig() + await Promise.all([ this.updateGlobalState("listApiConfigMeta", listApiConfig), this.updateApiConfiguration(message.apiConfiguration), @@ -1002,18 +1025,16 @@ export class ClineProvider implements vscode.WebviewViewProvider { try { const { oldName, newName } = message.values - await this.configManager.SaveConfig(newName, message.apiConfiguration); + await this.configManager.SaveConfig(newName, message.apiConfiguration) await this.configManager.DeleteConfig(oldName) - let listApiConfig = await this.configManager.ListConfig(); - const config = listApiConfig?.find(c => c.name === newName); - - // Update listApiConfigMeta first to ensure UI has latest data - await this.updateGlobalState("listApiConfigMeta", listApiConfig); + let listApiConfig = await this.configManager.ListConfig() + const config = listApiConfig?.find((c) => c.name === newName) - await Promise.all([ - this.updateGlobalState("currentApiConfigName", newName), - ]) + // Update listApiConfigMeta first to ensure UI has latest data + await this.updateGlobalState("listApiConfigMeta", listApiConfig) + + await Promise.all([this.updateGlobalState("currentApiConfigName", newName)]) await this.postStateToWebview() } catch (error) { @@ -1025,9 +1046,9 @@ export class ClineProvider implements vscode.WebviewViewProvider { case "loadApiConfiguration": if (message.text) { try { - const apiConfig = await this.configManager.LoadConfig(message.text); - const listApiConfig = await this.configManager.ListConfig(); - + const apiConfig = await this.configManager.LoadConfig(message.text) + const listApiConfig = await this.configManager.ListConfig() + await Promise.all([ this.updateGlobalState("listApiConfigMeta", listApiConfig), this.updateGlobalState("currentApiConfigName", message.text), @@ -1054,16 +1075,16 @@ export class ClineProvider implements vscode.WebviewViewProvider { } try { - await this.configManager.DeleteConfig(message.text); - const listApiConfig = await this.configManager.ListConfig(); - + await this.configManager.DeleteConfig(message.text) + const listApiConfig = await this.configManager.ListConfig() + // Update listApiConfigMeta first to ensure UI has latest data - await this.updateGlobalState("listApiConfigMeta", listApiConfig); + await this.updateGlobalState("listApiConfigMeta", listApiConfig) // If this was the current config, switch to first available let currentApiConfigName = await this.getGlobalState("currentApiConfigName") if (message.text === currentApiConfigName && listApiConfig?.[0]?.name) { - const apiConfig = await this.configManager.LoadConfig(listApiConfig[0].name); + const apiConfig = await this.configManager.LoadConfig(listApiConfig[0].name) await Promise.all([ this.updateGlobalState("currentApiConfigName", listApiConfig[0].name), this.updateApiConfiguration(apiConfig), @@ -1079,7 +1100,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { break case "getListApiConfiguration": try { - let listApiConfig = await this.configManager.ListConfig(); + let listApiConfig = await this.configManager.ListConfig() await this.updateGlobalState("listApiConfigMeta", listApiConfig) this.postMessageToWebview({ type: "listApiConfig", listApiConfig }) } catch (error) { @@ -1087,7 +1108,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { vscode.window.showErrorMessage("Failed to get list api configuration") } break - case "experimentalDiffStrategy": + case "experimentalDiffStrategy": await this.updateGlobalState("experimentalDiffStrategy", message.bool ?? false) // Update diffStrategy in current Cline instance if it exists if (this.cline) { @@ -1103,13 +1124,13 @@ export class ClineProvider implements vscode.WebviewViewProvider { private async updateApiConfiguration(apiConfiguration: ApiConfiguration) { // Update mode's default config - const { mode } = await this.getState(); + const { mode } = await this.getState() if (mode) { - const currentApiConfigName = await this.getGlobalState("currentApiConfigName"); - const listApiConfig = await this.configManager.ListConfig(); - const config = listApiConfig?.find(c => c.name === currentApiConfigName); + const currentApiConfigName = await this.getGlobalState("currentApiConfigName") + const listApiConfig = await this.configManager.ListConfig() + const config = listApiConfig?.find((c) => c.name === currentApiConfigName) if (config?.id) { - await this.configManager.SetModeConfig(mode, config.id); + await this.configManager.SetModeConfig(mode, config.id) } } @@ -1181,7 +1202,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.storeSecret("mistralApiKey", mistralApiKey) if (this.cline) { this.cline.api = buildApiHandler(apiConfiguration) - } + } } async updateCustomInstructions(instructions?: string) { @@ -1252,11 +1273,11 @@ export class ClineProvider implements vscode.WebviewViewProvider { // VSCode LM API private async getVsCodeLmModels() { try { - const models = await vscode.lm.selectChatModels({}); - return models || []; + const models = await vscode.lm.selectChatModels({}) + return models || [] } catch (error) { - console.error('Error fetching VS Code LM models:', error); - return []; + console.error("Error fetching VS Code LM models:", error) + return [] } } @@ -1346,10 +1367,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { } async readGlamaModels(): Promise | undefined> { - const glamaModelsFilePath = path.join( - await this.ensureCacheDirectoryExists(), - GlobalFileNames.glamaModels, - ) + const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels) const fileExists = await fileExistsAtPath(glamaModelsFilePath) if (fileExists) { const fileContents = await fs.readFile(glamaModelsFilePath, "utf8") @@ -1359,10 +1377,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { } async refreshGlamaModels() { - const glamaModelsFilePath = path.join( - await this.ensureCacheDirectoryExists(), - GlobalFileNames.glamaModels, - ) + const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels) let models: Record = {} try { @@ -1397,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { } */ if (response.data) { - const rawModels = response.data; + const rawModels = response.data const parsePrice = (price: any) => { if (price) { return parseFloat(price) * 1_000_000 @@ -1565,7 +1580,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { uiMessagesFilePath: string apiConversationHistory: Anthropic.MessageParam[] }> { - const history = (await this.getGlobalState("taskHistory") as HistoryItem[] | undefined) || [] + const history = ((await this.getGlobalState("taskHistory")) as HistoryItem[] | undefined) || [] const historyItem = history.find((item) => item.id === id) if (historyItem) { const taskDirPath = path.join(this.context.globalStorageUri.fsPath, "tasks", id) @@ -1630,7 +1645,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { async deleteTaskFromState(id: string) { // Remove the task from history - const taskHistory = (await this.getGlobalState("taskHistory") as HistoryItem[]) || [] + const taskHistory = ((await this.getGlobalState("taskHistory")) as HistoryItem[]) || [] const updatedTaskHistory = taskHistory.filter((task) => task.id !== id) await this.updateGlobalState("taskHistory", updatedTaskHistory) @@ -1671,13 +1686,11 @@ export class ClineProvider implements vscode.WebviewViewProvider { mode, customPrompts, enhancementApiConfigId, - experimentalDiffStrategy, + experimentalDiffStrategy, autoApprovalEnabled, } = await this.getState() - const allowedCommands = vscode.workspace - .getConfiguration('roo-cline') - .get('allowedCommands') || [] + const allowedCommands = vscode.workspace.getConfiguration("roo-cline").get("allowedCommands") || [] return { version: this.context.extension?.packageJSON?.version ?? "", @@ -1700,7 +1713,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { soundVolume: soundVolume ?? 0.5, browserViewportSize: browserViewportSize ?? "900x600", screenshotQuality: screenshotQuality ?? 75, - preferredLanguage: preferredLanguage ?? 'English', + preferredLanguage: preferredLanguage ?? "English", writeDelayMs: writeDelayMs ?? 1000, terminalOutputLineLimit: terminalOutputLineLimit ?? 500, fuzzyMatchThreshold: fuzzyMatchThreshold ?? 1.0, @@ -1712,7 +1725,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { mode: mode ?? defaultModeSlug, customPrompts: customPrompts ?? {}, enhancementApiConfigId, - experimentalDiffStrategy: experimentalDiffStrategy ?? false, + experimentalDiffStrategy: experimentalDiffStrategy ?? false, autoApprovalEnabled: autoApprovalEnabled ?? false, } } @@ -1829,7 +1842,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { modeApiConfigs, customPrompts, enhancementApiConfigId, - experimentalDiffStrategy, + experimentalDiffStrategy, autoApprovalEnabled, ] = await Promise.all([ this.getGlobalState("apiProvider") as Promise, @@ -1891,7 +1904,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getGlobalState("modeApiConfigs") as Promise | undefined>, this.getGlobalState("customPrompts") as Promise, this.getGlobalState("enhancementApiConfigId") as Promise, - this.getGlobalState("experimentalDiffStrategy") as Promise, + this.getGlobalState("experimentalDiffStrategy") as Promise, this.getGlobalState("autoApprovalEnabled") as Promise, ]) @@ -1962,48 +1975,50 @@ export class ClineProvider implements vscode.WebviewViewProvider { writeDelayMs: writeDelayMs ?? 1000, terminalOutputLineLimit: terminalOutputLineLimit ?? 500, mode: mode ?? defaultModeSlug, - preferredLanguage: preferredLanguage ?? (() => { - // Get VSCode's locale setting - const vscodeLang = vscode.env.language; - // Map VSCode locale to our supported languages - const langMap: { [key: string]: string } = { - 'en': 'English', - 'ar': 'Arabic', - 'pt-br': 'Brazilian Portuguese', - 'cs': 'Czech', - 'fr': 'French', - 'de': 'German', - 'hi': 'Hindi', - 'hu': 'Hungarian', - 'it': 'Italian', - 'ja': 'Japanese', - 'ko': 'Korean', - 'pl': 'Polish', - 'pt': 'Portuguese', - 'ru': 'Russian', - 'zh-cn': 'Simplified Chinese', - 'es': 'Spanish', - 'zh-tw': 'Traditional Chinese', - 'tr': 'Turkish' - }; - // Return mapped language or default to English - return langMap[vscodeLang.split('-')[0]] ?? 'English'; - })(), + preferredLanguage: + preferredLanguage ?? + (() => { + // Get VSCode's locale setting + const vscodeLang = vscode.env.language + // Map VSCode locale to our supported languages + const langMap: { [key: string]: string } = { + en: "English", + ar: "Arabic", + "pt-br": "Brazilian Portuguese", + cs: "Czech", + fr: "French", + de: "German", + hi: "Hindi", + hu: "Hungarian", + it: "Italian", + ja: "Japanese", + ko: "Korean", + pl: "Polish", + pt: "Portuguese", + ru: "Russian", + "zh-cn": "Simplified Chinese", + es: "Spanish", + "zh-tw": "Traditional Chinese", + tr: "Turkish", + } + // Return mapped language or default to English + return langMap[vscodeLang.split("-")[0]] ?? "English" + })(), mcpEnabled: mcpEnabled ?? true, alwaysApproveResubmit: alwaysApproveResubmit ?? false, requestDelaySeconds: requestDelaySeconds ?? 5, currentApiConfigName: currentApiConfigName ?? "default", listApiConfigMeta: listApiConfigMeta ?? [], - modeApiConfigs: modeApiConfigs ?? {} as Record, + modeApiConfigs: modeApiConfigs ?? ({} as Record), customPrompts: customPrompts ?? {}, enhancementApiConfigId, - experimentalDiffStrategy: experimentalDiffStrategy ?? false, + experimentalDiffStrategy: experimentalDiffStrategy ?? false, autoApprovalEnabled: autoApprovalEnabled ?? false, } } async updateTaskHistory(item: HistoryItem): Promise { - const history = (await this.getGlobalState("taskHistory") as HistoryItem[] | undefined) || [] + const history = ((await this.getGlobalState("taskHistory")) as HistoryItem[] | undefined) || [] const existingItemIndex = history.findIndex((h) => h.id === item.id) if (existingItemIndex !== -1) { diff --git a/src/core/webview/__tests__/ClineProvider.test.ts b/src/core/webview/__tests__/ClineProvider.test.ts index 98093e3..52fb611 100644 --- a/src/core/webview/__tests__/ClineProvider.test.ts +++ b/src/core/webview/__tests__/ClineProvider.test.ts @@ -1,1010 +1,990 @@ -import { ClineProvider } from '../ClineProvider' -import * as vscode from 'vscode' -import { ExtensionMessage, ExtensionState } from '../../../shared/ExtensionMessage' -import { setSoundEnabled } from '../../../utils/sound' -import { defaultModeSlug, modes } from '../../../shared/modes'; +import { ClineProvider } from "../ClineProvider" +import * as vscode from "vscode" +import { ExtensionMessage, ExtensionState } from "../../../shared/ExtensionMessage" +import { setSoundEnabled } from "../../../utils/sound" +import { defaultModeSlug, modes } from "../../../shared/modes" // Mock delay module -jest.mock('delay', () => { - const delayFn = (ms: number) => Promise.resolve(); - delayFn.createDelay = () => delayFn; - delayFn.reject = () => Promise.reject(new Error('Delay rejected')); - delayFn.range = () => Promise.resolve(); - return delayFn; -}); +jest.mock("delay", () => { + const delayFn = (ms: number) => Promise.resolve() + delayFn.createDelay = () => delayFn + delayFn.reject = () => Promise.reject(new Error("Delay rejected")) + delayFn.range = () => Promise.resolve() + return delayFn +}) // Mock MCP-related modules -jest.mock('@modelcontextprotocol/sdk/types.js', () => ({ - CallToolResultSchema: {}, - ListResourcesResultSchema: {}, - ListResourceTemplatesResultSchema: {}, - ListToolsResultSchema: {}, - ReadResourceResultSchema: {}, - ErrorCode: { - InvalidRequest: 'InvalidRequest', - MethodNotFound: 'MethodNotFound', - InternalError: 'InternalError' - }, - McpError: class McpError extends Error { - code: string; - constructor(code: string, message: string) { - super(message); - this.code = code; - this.name = 'McpError'; - } - } -}), { virtual: true }); +jest.mock( + "@modelcontextprotocol/sdk/types.js", + () => ({ + CallToolResultSchema: {}, + ListResourcesResultSchema: {}, + ListResourceTemplatesResultSchema: {}, + ListToolsResultSchema: {}, + ReadResourceResultSchema: {}, + ErrorCode: { + InvalidRequest: "InvalidRequest", + MethodNotFound: "MethodNotFound", + InternalError: "InternalError", + }, + McpError: class McpError extends Error { + code: string + constructor(code: string, message: string) { + super(message) + this.code = code + this.name = "McpError" + } + }, + }), + { virtual: true }, +) -jest.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ - Client: jest.fn().mockImplementation(() => ({ - connect: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - listTools: jest.fn().mockResolvedValue({ tools: [] }), - callTool: jest.fn().mockResolvedValue({ content: [] }) - })) -}), { virtual: true }); +jest.mock( + "@modelcontextprotocol/sdk/client/index.js", + () => ({ + Client: jest.fn().mockImplementation(() => ({ + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + listTools: jest.fn().mockResolvedValue({ tools: [] }), + callTool: jest.fn().mockResolvedValue({ content: [] }), + })), + }), + { virtual: true }, +) -jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({ - StdioClientTransport: jest.fn().mockImplementation(() => ({ - connect: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined) - })) -}), { virtual: true }); +jest.mock( + "@modelcontextprotocol/sdk/client/stdio.js", + () => ({ + StdioClientTransport: jest.fn().mockImplementation(() => ({ + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + })), + }), + { virtual: true }, +) // Mock dependencies -jest.mock('vscode', () => ({ - ExtensionContext: jest.fn(), - OutputChannel: jest.fn(), - WebviewView: jest.fn(), - Uri: { - joinPath: jest.fn(), - file: jest.fn() - }, - window: { - showInformationMessage: jest.fn(), - showErrorMessage: jest.fn(), - }, - workspace: { - getConfiguration: jest.fn().mockReturnValue({ - get: jest.fn().mockReturnValue([]), - update: jest.fn() - }), - onDidChangeConfiguration: jest.fn().mockImplementation((callback) => ({ - dispose: jest.fn() - })), - onDidSaveTextDocument: jest.fn(() => ({ dispose: jest.fn() })), - onDidChangeTextDocument: jest.fn(() => ({ dispose: jest.fn() })), - onDidOpenTextDocument: jest.fn(() => ({ dispose: jest.fn() })), - onDidCloseTextDocument: jest.fn(() => ({ dispose: jest.fn() })) - }, - env: { - uriScheme: 'vscode', - language: 'en' - } +jest.mock("vscode", () => ({ + ExtensionContext: jest.fn(), + OutputChannel: jest.fn(), + WebviewView: jest.fn(), + Uri: { + joinPath: jest.fn(), + file: jest.fn(), + }, + window: { + showInformationMessage: jest.fn(), + showErrorMessage: jest.fn(), + }, + workspace: { + getConfiguration: jest.fn().mockReturnValue({ + get: jest.fn().mockReturnValue([]), + update: jest.fn(), + }), + onDidChangeConfiguration: jest.fn().mockImplementation((callback) => ({ + dispose: jest.fn(), + })), + onDidSaveTextDocument: jest.fn(() => ({ dispose: jest.fn() })), + onDidChangeTextDocument: jest.fn(() => ({ dispose: jest.fn() })), + onDidOpenTextDocument: jest.fn(() => ({ dispose: jest.fn() })), + onDidCloseTextDocument: jest.fn(() => ({ dispose: jest.fn() })), + }, + env: { + uriScheme: "vscode", + language: "en", + }, })) // Mock sound utility -jest.mock('../../../utils/sound', () => ({ - setSoundEnabled: jest.fn() +jest.mock("../../../utils/sound", () => ({ + setSoundEnabled: jest.fn(), })) // Mock ESM modules -jest.mock('p-wait-for', () => ({ - __esModule: true, - default: jest.fn().mockResolvedValue(undefined) +jest.mock("p-wait-for", () => ({ + __esModule: true, + default: jest.fn().mockResolvedValue(undefined), })) // Mock fs/promises -jest.mock('fs/promises', () => ({ - mkdir: jest.fn(), - writeFile: jest.fn(), - readFile: jest.fn(), - unlink: jest.fn(), - rmdir: jest.fn() +jest.mock("fs/promises", () => ({ + mkdir: jest.fn(), + writeFile: jest.fn(), + readFile: jest.fn(), + unlink: jest.fn(), + rmdir: jest.fn(), })) // Mock axios -jest.mock('axios', () => ({ - get: jest.fn().mockResolvedValue({ data: { data: [] } }), - post: jest.fn() +jest.mock("axios", () => ({ + get: jest.fn().mockResolvedValue({ data: { data: [] } }), + post: jest.fn(), })) // Mock buildApiHandler -jest.mock('../../../api', () => ({ - buildApiHandler: jest.fn() +jest.mock("../../../api", () => ({ + buildApiHandler: jest.fn(), })) // Mock system prompt -jest.mock('../../prompts/system', () => ({ - SYSTEM_PROMPT: jest.fn().mockImplementation(async () => 'mocked system prompt'), - codeMode: 'code', - addCustomInstructions: jest.fn().mockImplementation(async () => '') +jest.mock("../../prompts/system", () => ({ + SYSTEM_PROMPT: jest.fn().mockImplementation(async () => "mocked system prompt"), + codeMode: "code", + addCustomInstructions: jest.fn().mockImplementation(async () => ""), })) // Mock WorkspaceTracker -jest.mock('../../../integrations/workspace/WorkspaceTracker', () => { - return jest.fn().mockImplementation(() => ({ - initializeFilePaths: jest.fn(), - dispose: jest.fn() - })) +jest.mock("../../../integrations/workspace/WorkspaceTracker", () => { + return jest.fn().mockImplementation(() => ({ + initializeFilePaths: jest.fn(), + dispose: jest.fn(), + })) }) // Mock Cline -jest.mock('../../Cline', () => ({ - Cline: jest.fn().mockImplementation(( - provider, - apiConfiguration, - customInstructions, - diffEnabled, - fuzzyMatchThreshold, - task, - taskId - ) => ({ - abortTask: jest.fn(), - handleWebviewAskResponse: jest.fn(), - clineMessages: [], - apiConversationHistory: [], - overwriteClineMessages: jest.fn(), - overwriteApiConversationHistory: jest.fn(), - taskId: taskId || 'test-task-id' - })) +jest.mock("../../Cline", () => ({ + Cline: jest + .fn() + .mockImplementation( + (provider, apiConfiguration, customInstructions, diffEnabled, fuzzyMatchThreshold, task, taskId) => ({ + abortTask: jest.fn(), + handleWebviewAskResponse: jest.fn(), + clineMessages: [], + apiConversationHistory: [], + overwriteClineMessages: jest.fn(), + overwriteApiConversationHistory: jest.fn(), + taskId: taskId || "test-task-id", + }), + ), })) // Mock extract-text -jest.mock('../../../integrations/misc/extract-text', () => ({ - extractTextFromFile: jest.fn().mockImplementation(async (filePath: string) => { - const content = 'const x = 1;\nconst y = 2;\nconst z = 3;' - const lines = content.split('\n') - return lines.map((line, index) => `${index + 1} | ${line}`).join('\n') - }) +jest.mock("../../../integrations/misc/extract-text", () => ({ + extractTextFromFile: jest.fn().mockImplementation(async (filePath: string) => { + const content = "const x = 1;\nconst y = 2;\nconst z = 3;" + const lines = content.split("\n") + return lines.map((line, index) => `${index + 1} | ${line}`).join("\n") + }), })) // Spy on console.error and console.log to suppress expected messages beforeAll(() => { - jest.spyOn(console, 'error').mockImplementation(() => {}) - jest.spyOn(console, 'log').mockImplementation(() => {}) + jest.spyOn(console, "error").mockImplementation(() => {}) + jest.spyOn(console, "log").mockImplementation(() => {}) }) afterAll(() => { - jest.restoreAllMocks() + jest.restoreAllMocks() }) -describe('ClineProvider', () => { - let provider: ClineProvider - let mockContext: vscode.ExtensionContext - let mockOutputChannel: vscode.OutputChannel - let mockWebviewView: vscode.WebviewView - let mockPostMessage: jest.Mock - let visibilityChangeCallback: (e?: unknown) => void - - beforeEach(() => { - // Reset mocks - jest.clearAllMocks() - - // Mock context - mockContext = { - extensionPath: '/test/path', - extensionUri: {} as vscode.Uri, - globalState: { - get: jest.fn().mockImplementation((key: string) => { - switch (key) { - case 'mode': - return 'architect' - case 'currentApiConfigName': - return 'new-config' - default: - return undefined - } - }), - update: jest.fn(), - keys: jest.fn().mockReturnValue([]), - }, - secrets: { - get: jest.fn(), - store: jest.fn(), - delete: jest.fn() - }, - subscriptions: [], - extension: { - packageJSON: { version: '1.0.0' } - }, - globalStorageUri: { - fsPath: '/test/storage/path' - } - } as unknown as vscode.ExtensionContext - - // Mock output channel - mockOutputChannel = { - appendLine: jest.fn(), - clear: jest.fn(), - dispose: jest.fn() - } as unknown as vscode.OutputChannel - - // Mock webview - mockPostMessage = jest.fn() - mockWebviewView = { - webview: { - postMessage: mockPostMessage, - html: '', - options: {}, - onDidReceiveMessage: jest.fn(), - asWebviewUri: jest.fn() - }, - visible: true, - onDidDispose: jest.fn().mockImplementation((callback) => { - callback() - return { dispose: jest.fn() } - }), - onDidChangeVisibility: jest.fn().mockImplementation((callback) => { - visibilityChangeCallback = callback - return { dispose: jest.fn() } - }) - } as unknown as vscode.WebviewView - - provider = new ClineProvider(mockContext, mockOutputChannel) - }) - - test('constructor initializes correctly', () => { - expect(provider).toBeInstanceOf(ClineProvider) - // Since getVisibleInstance returns the last instance where view.visible is true - // @ts-ignore - accessing private property for testing - provider.view = mockWebviewView - expect(ClineProvider.getVisibleInstance()).toBe(provider) - }) - - test('resolveWebviewView sets up webview correctly', () => { - provider.resolveWebviewView(mockWebviewView) - - expect(mockWebviewView.webview.options).toEqual({ - enableScripts: true, - localResourceRoots: [mockContext.extensionUri] - }) - expect(mockWebviewView.webview.html).toContain('') - }) - - test('postMessageToWebview sends message to webview', async () => { - provider.resolveWebviewView(mockWebviewView) - - const mockState: ExtensionState = { - version: '1.0.0', - preferredLanguage: 'English', - clineMessages: [], - taskHistory: [], - shouldShowAnnouncement: false, - apiConfiguration: { - apiProvider: 'openrouter' - }, - customInstructions: undefined, - alwaysAllowReadOnly: false, - alwaysAllowWrite: false, - alwaysAllowExecute: false, - alwaysAllowBrowser: false, - alwaysAllowMcp: false, - uriScheme: 'vscode', - soundEnabled: false, - diffEnabled: false, - writeDelayMs: 1000, - browserViewportSize: "900x600", - fuzzyMatchThreshold: 1.0, - mcpEnabled: true, - requestDelaySeconds: 5, - mode: defaultModeSlug, - } - - const message: ExtensionMessage = { - type: 'state', - state: mockState - } - await provider.postMessageToWebview(message) - - expect(mockPostMessage).toHaveBeenCalledWith(message) - }) - - test('handles webviewDidLaunch message', async () => { - provider.resolveWebviewView(mockWebviewView) - - // Get the message handler from onDidReceiveMessage - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Simulate webviewDidLaunch message - await messageHandler({ type: 'webviewDidLaunch' }) - - // Should post state and theme to webview - expect(mockPostMessage).toHaveBeenCalled() - }) - - test('clearTask aborts current task', async () => { - const mockAbortTask = jest.fn() - // @ts-ignore - accessing private property for testing - provider.cline = { abortTask: mockAbortTask } - - await provider.clearTask() - - expect(mockAbortTask).toHaveBeenCalled() - // @ts-ignore - accessing private property for testing - expect(provider.cline).toBeUndefined() - }) - - test('getState returns correct initial state', async () => { - const state = await provider.getState() - - expect(state).toHaveProperty('apiConfiguration') - expect(state.apiConfiguration).toHaveProperty('apiProvider') - expect(state).toHaveProperty('customInstructions') - expect(state).toHaveProperty('alwaysAllowReadOnly') - expect(state).toHaveProperty('alwaysAllowWrite') - expect(state).toHaveProperty('alwaysAllowExecute') - expect(state).toHaveProperty('alwaysAllowBrowser') - expect(state).toHaveProperty('taskHistory') - expect(state).toHaveProperty('soundEnabled') - expect(state).toHaveProperty('diffEnabled') - expect(state).toHaveProperty('writeDelayMs') - }) - - test('preferredLanguage defaults to VSCode language when not set', async () => { - // Mock VSCode language as Spanish - (vscode.env as any).language = 'es-ES'; - - const state = await provider.getState(); - expect(state.preferredLanguage).toBe('Spanish'); - }) - - test('preferredLanguage defaults to English for unsupported VSCode language', async () => { - // Mock VSCode language as an unsupported language - (vscode.env as any).language = 'unsupported-LANG'; - - const state = await provider.getState(); - expect(state.preferredLanguage).toBe('English'); - }) - - test('diffEnabled defaults to true when not set', async () => { - // Mock globalState.get to return undefined for diffEnabled - (mockContext.globalState.get as jest.Mock).mockReturnValue(undefined) - - const state = await provider.getState() - - expect(state.diffEnabled).toBe(true) - }) - - test('writeDelayMs defaults to 1000ms', async () => { - // Mock globalState.get to return undefined for writeDelayMs - (mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { - if (key === 'writeDelayMs') { - return undefined - } - return null - }) - - const state = await provider.getState() - expect(state.writeDelayMs).toBe(1000) - }) - - test('handles writeDelayMs message', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - await messageHandler({ type: 'writeDelayMs', value: 2000 }) - - expect(mockContext.globalState.update).toHaveBeenCalledWith('writeDelayMs', 2000) - expect(mockPostMessage).toHaveBeenCalled() - }) - - test('updates sound utility when sound setting changes', async () => { - provider.resolveWebviewView(mockWebviewView) - - // Get the message handler from onDidReceiveMessage - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Simulate setting sound to enabled - await messageHandler({ type: 'soundEnabled', bool: true }) - expect(setSoundEnabled).toHaveBeenCalledWith(true) - expect(mockContext.globalState.update).toHaveBeenCalledWith('soundEnabled', true) - expect(mockPostMessage).toHaveBeenCalled() - - // Simulate setting sound to disabled - await messageHandler({ type: 'soundEnabled', bool: false }) - expect(setSoundEnabled).toHaveBeenCalledWith(false) - expect(mockContext.globalState.update).toHaveBeenCalledWith('soundEnabled', false) - expect(mockPostMessage).toHaveBeenCalled() - }) - - test('requestDelaySeconds defaults to 5 seconds', async () => { - // Mock globalState.get to return undefined for requestDelaySeconds - (mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { - if (key === 'requestDelaySeconds') { - return undefined - } - return null - }) - - const state = await provider.getState() - expect(state.requestDelaySeconds).toBe(5) - }) - - test('alwaysApproveResubmit defaults to false', async () => { - // Mock globalState.get to return undefined for alwaysApproveResubmit - (mockContext.globalState.get as jest.Mock).mockReturnValue(undefined) - - const state = await provider.getState() - expect(state.alwaysApproveResubmit).toBe(false) - }) - - test('loads saved API config when switching modes', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Mock ConfigManager methods - provider.configManager = { - GetModeConfigId: jest.fn().mockResolvedValue('test-id'), - ListConfig: jest.fn().mockResolvedValue([ - { name: 'test-config', id: 'test-id', apiProvider: 'anthropic' } - ]), - LoadConfig: jest.fn().mockResolvedValue({ apiProvider: 'anthropic' }), - SetModeConfig: jest.fn() - } as any - - // Switch to architect mode - await messageHandler({ type: 'mode', text: 'architect' }) - - // Should load the saved config for architect mode - expect(provider.configManager.GetModeConfigId).toHaveBeenCalledWith('architect') - expect(provider.configManager.LoadConfig).toHaveBeenCalledWith('test-config') - expect(mockContext.globalState.update).toHaveBeenCalledWith('currentApiConfigName', 'test-config') - }) - - test('saves current config when switching to mode without config', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Mock ConfigManager methods - provider.configManager = { - GetModeConfigId: jest.fn().mockResolvedValue(undefined), - ListConfig: jest.fn().mockResolvedValue([ - { name: 'current-config', id: 'current-id', apiProvider: 'anthropic' } - ]), - SetModeConfig: jest.fn() - } as any - - // Mock current config name - (mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { - if (key === 'currentApiConfigName') { - return 'current-config' - } - return undefined - }) - - // Switch to architect mode - await messageHandler({ type: 'mode', text: 'architect' }) - - // Should save current config as default for architect mode - expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith('architect', 'current-id') - }) - - test('saves config as default for current mode when loading config', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - provider.configManager = { - LoadConfig: jest.fn().mockResolvedValue({ apiProvider: 'anthropic', id: 'new-id' }), - ListConfig: jest.fn().mockResolvedValue([ - { name: 'new-config', id: 'new-id', apiProvider: 'anthropic' } - ]), - SetModeConfig: jest.fn(), - GetModeConfigId: jest.fn().mockResolvedValue(undefined) - } as any - - // First set the mode - await messageHandler({ type: 'mode', text: 'architect' }) - - // Then load the config - await messageHandler({ type: 'loadApiConfiguration', text: 'new-config' }) - - // Should save new config as default for architect mode - expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith('architect', 'new-id') - }) - - test('handles request delay settings messages', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Test alwaysApproveResubmit - await messageHandler({ type: 'alwaysApproveResubmit', bool: true }) - expect(mockContext.globalState.update).toHaveBeenCalledWith('alwaysApproveResubmit', true) - expect(mockPostMessage).toHaveBeenCalled() - - // Test requestDelaySeconds - await messageHandler({ type: 'requestDelaySeconds', value: 10 }) - expect(mockContext.globalState.update).toHaveBeenCalledWith('requestDelaySeconds', 10) - expect(mockPostMessage).toHaveBeenCalled() - }) - - test('handles updatePrompt message correctly', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Mock existing prompts - const existingPrompts = { - code: 'existing code prompt', - architect: 'existing architect prompt' - } - ;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { - if (key === 'customPrompts') { - return existingPrompts - } - return undefined - }) - - // Test updating a prompt - await messageHandler({ - type: 'updatePrompt', - promptMode: 'code', - customPrompt: 'new code prompt' - }) - - // Verify state was updated correctly - expect(mockContext.globalState.update).toHaveBeenCalledWith( - 'customPrompts', - { - ...existingPrompts, - code: 'new code prompt' - } - ) - - // Verify state was posted to webview - expect(mockPostMessage).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'state', - state: expect.objectContaining({ - customPrompts: { - ...existingPrompts, - code: 'new code prompt' - } - }) - }) - ) - }) - - test('customPrompts defaults to empty object', async () => { - // Mock globalState.get to return undefined for customPrompts - (mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { - if (key === 'customPrompts') { - return undefined - } - return null - }) - - const state = await provider.getState() - expect(state.customPrompts).toEqual({}) - }) - - test('uses mode-specific custom instructions in Cline initialization', async () => { - // Setup mock state - const modeCustomInstructions = 'Code mode instructions'; - const mockApiConfig = { - apiProvider: 'openrouter', - openRouterModelInfo: { supportsComputerUse: true } - }; - - jest.spyOn(provider, 'getState').mockResolvedValue({ - apiConfiguration: mockApiConfig, - customPrompts: { - code: { customInstructions: modeCustomInstructions } - }, - mode: 'code', - diffEnabled: true, - fuzzyMatchThreshold: 1.0 - } as any); - - // Reset Cline mock - const { Cline } = require('../../Cline'); - (Cline as jest.Mock).mockClear(); - - // Initialize Cline with a task - await provider.initClineWithTask('Test task'); - - // Verify Cline was initialized with mode-specific instructions - expect(Cline).toHaveBeenCalledWith( - provider, - mockApiConfig, - modeCustomInstructions, - true, - 1.0, - 'Test task', - undefined, - undefined, - undefined - ); - }); - test('handles mode-specific custom instructions updates', async () => { - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Mock existing prompts - const existingPrompts = { - code: { - roleDefinition: 'Code role', - customInstructions: 'Old instructions' - } - } - mockContext.globalState.get = jest.fn((key: string) => { - if (key === 'customPrompts') { - return existingPrompts - } - return undefined - }) - - // Update custom instructions for code mode - await messageHandler({ - type: 'updatePrompt', - promptMode: 'code', - customPrompt: { - roleDefinition: 'Code role', - customInstructions: 'New instructions' - } - }) - - // Verify state was updated correctly - expect(mockContext.globalState.update).toHaveBeenCalledWith( - 'customPrompts', - { - code: { - roleDefinition: 'Code role', - customInstructions: 'New instructions' - } - } - ) - }) - - test('saves mode config when updating API configuration', async () => { - // Setup mock context with mode and config name - mockContext = { - ...mockContext, - globalState: { - ...mockContext.globalState, - get: jest.fn((key: string) => { - if (key === 'mode') { - return 'code' - } else if (key === 'currentApiConfigName') { - return 'test-config' - } - return undefined - }), - update: jest.fn(), - keys: jest.fn().mockReturnValue([]), - } - } as unknown as vscode.ExtensionContext - - // Create new provider with updated mock context - provider = new ClineProvider(mockContext, mockOutputChannel) - provider.resolveWebviewView(mockWebviewView) - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - provider.configManager = { - ListConfig: jest.fn().mockResolvedValue([ - { name: 'test-config', id: 'test-id', apiProvider: 'anthropic' } - ]), - SetModeConfig: jest.fn() - } as any - - // Update API configuration - await messageHandler({ - type: 'apiConfiguration', - apiConfiguration: { apiProvider: 'anthropic' } - }) - - // Should save config as default for current mode - expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith('code', 'test-id') - }) - - test('file content includes line numbers', async () => { - const { extractTextFromFile } = require('../../../integrations/misc/extract-text') - const result = await extractTextFromFile('test.js') - expect(result).toBe('1 | const x = 1;\n2 | const y = 2;\n3 | const z = 3;') - }) - - describe('deleteMessage', () => { - beforeEach(() => { - // Mock window.showInformationMessage - ;(vscode.window.showInformationMessage as jest.Mock) = jest.fn() - provider.resolveWebviewView(mockWebviewView) - }) - - test('handles "Just this message" deletion correctly', async () => { - // Mock user selecting "Just this message" - ;(vscode.window.showInformationMessage as jest.Mock).mockResolvedValue('Just this message') - - // Setup mock messages - const mockMessages = [ - { ts: 1000, type: 'say', say: 'user_feedback' }, // User message 1 - { ts: 2000, type: 'say', say: 'tool' }, // Tool message - { ts: 3000, type: 'say', say: 'text', value: 4000 }, // Message to delete - { ts: 4000, type: 'say', say: 'browser_action' }, // Response to delete - { ts: 5000, type: 'say', say: 'user_feedback' }, // Next user message - { ts: 6000, type: 'say', say: 'user_feedback' } // Final message - ] - - const mockApiHistory = [ - { ts: 1000 }, - { ts: 2000 }, - { ts: 3000 }, - { ts: 4000 }, - { ts: 5000 }, - { ts: 6000 } - ] - - // Setup Cline instance with mock data - const mockCline = { - clineMessages: mockMessages, - apiConversationHistory: mockApiHistory, - overwriteClineMessages: jest.fn(), - overwriteApiConversationHistory: jest.fn(), - taskId: 'test-task-id', - abortTask: jest.fn(), - handleWebviewAskResponse: jest.fn() - } - // @ts-ignore - accessing private property for testing - provider.cline = mockCline - - // Mock getTaskWithId - ;(provider as any).getTaskWithId = jest.fn().mockResolvedValue({ - historyItem: { id: 'test-task-id' } - }) - - // Trigger message deletion - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - await messageHandler({ type: 'deleteMessage', value: 4000 }) - - // Verify correct messages were kept - expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([ - mockMessages[0], - mockMessages[1], - mockMessages[4], - mockMessages[5] - ]) - - // Verify correct API messages were kept - expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([ - mockApiHistory[0], - mockApiHistory[1], - mockApiHistory[4], - mockApiHistory[5] - ]) - }) - - test('handles "This and all subsequent messages" deletion correctly', async () => { - // Mock user selecting "This and all subsequent messages" - ;(vscode.window.showInformationMessage as jest.Mock).mockResolvedValue('This and all subsequent messages') - - // Setup mock messages - const mockMessages = [ - { ts: 1000, type: 'say', say: 'user_feedback' }, - { ts: 2000, type: 'say', say: 'text', value: 3000 }, // Message to delete - { ts: 3000, type: 'say', say: 'user_feedback' }, - { ts: 4000, type: 'say', say: 'user_feedback' } - ] - - const mockApiHistory = [ - { ts: 1000 }, - { ts: 2000 }, - { ts: 3000 }, - { ts: 4000 } - ] - - // Setup Cline instance with mock data - const mockCline = { - clineMessages: mockMessages, - apiConversationHistory: mockApiHistory, - overwriteClineMessages: jest.fn(), - overwriteApiConversationHistory: jest.fn(), - taskId: 'test-task-id', - abortTask: jest.fn(), - handleWebviewAskResponse: jest.fn() - } - // @ts-ignore - accessing private property for testing - provider.cline = mockCline - - // Mock getTaskWithId - ;(provider as any).getTaskWithId = jest.fn().mockResolvedValue({ - historyItem: { id: 'test-task-id' } - }) - - // Trigger message deletion - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - await messageHandler({ type: 'deleteMessage', value: 3000 }) - - // Verify only messages before the deleted message were kept - expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([ - mockMessages[0] - ]) - - // Verify only API messages before the deleted message were kept - expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([ - mockApiHistory[0] - ]) - }) - - test('handles Cancel correctly', async () => { - // Mock user selecting "Cancel" - ;(vscode.window.showInformationMessage as jest.Mock).mockResolvedValue('Cancel') - - const mockCline = { - clineMessages: [{ ts: 1000 }, { ts: 2000 }], - apiConversationHistory: [{ ts: 1000 }, { ts: 2000 }], - overwriteClineMessages: jest.fn(), - overwriteApiConversationHistory: jest.fn(), - taskId: 'test-task-id' - } - // @ts-ignore - accessing private property for testing - provider.cline = mockCline - - // Trigger message deletion - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - await messageHandler({ type: 'deleteMessage', value: 2000 }) - - // Verify no messages were deleted - expect(mockCline.overwriteClineMessages).not.toHaveBeenCalled() - expect(mockCline.overwriteApiConversationHistory).not.toHaveBeenCalled() - }) - }) - - describe('getSystemPrompt', () => { - beforeEach(() => { - mockPostMessage.mockClear(); - provider.resolveWebviewView(mockWebviewView); - }); - - const getMessageHandler = () => { - const mockCalls = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls; - expect(mockCalls.length).toBeGreaterThan(0); - return mockCalls[0][0]; - }; - - test('handles mcpEnabled setting correctly', async () => { - // Mock getState to return mcpEnabled: true - jest.spyOn(provider, 'getState').mockResolvedValue({ - apiConfiguration: { - apiProvider: 'openrouter' as const, - openRouterModelInfo: { - supportsComputerUse: true, - supportsPromptCache: false, - maxTokens: 4096, - contextWindow: 8192, - supportsImages: false, - inputPrice: 0.0, - outputPrice: 0.0, - description: undefined - } - }, - mcpEnabled: true, - mode: 'code' as const - } as any); - - const handler1 = getMessageHandler(); - expect(typeof handler1).toBe('function'); - await handler1({ type: 'getSystemPrompt', mode: 'code' }); - - // Verify mcpHub is passed when mcpEnabled is true - expect(mockPostMessage).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'systemPrompt', - text: expect.any(String) - }) - ); - - // Mock getState to return mcpEnabled: false - jest.spyOn(provider, 'getState').mockResolvedValue({ - apiConfiguration: { - apiProvider: 'openrouter' as const, - openRouterModelInfo: { - supportsComputerUse: true, - supportsPromptCache: false, - maxTokens: 4096, - contextWindow: 8192, - supportsImages: false, - inputPrice: 0.0, - outputPrice: 0.0, - description: undefined - } - }, - mcpEnabled: false, - mode: 'code' as const - } as any); - - const handler2 = getMessageHandler(); - await handler2({ type: 'getSystemPrompt', mode: 'code' }); - - // Verify mcpHub is not passed when mcpEnabled is false - expect(mockPostMessage).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'systemPrompt', - text: expect.any(String) - }) - ); - }); - - test('handles errors gracefully', async () => { - // Mock SYSTEM_PROMPT to throw an error - const systemPrompt = require('../../prompts/system') - jest.spyOn(systemPrompt, 'SYSTEM_PROMPT').mockRejectedValueOnce(new Error('Test error')) - - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - await messageHandler({ type: 'getSystemPrompt', mode: 'code' }) - - expect(vscode.window.showErrorMessage).toHaveBeenCalledWith('Failed to get system prompt') - }) - - test('uses mode-specific custom instructions in system prompt', async () => { - const systemPrompt = require('../../prompts/system') - const { addCustomInstructions } = systemPrompt - - // Mock getState to return mode-specific custom instructions - jest.spyOn(provider, 'getState').mockResolvedValue({ - apiConfiguration: { - apiProvider: 'openrouter', - openRouterModelInfo: { supportsComputerUse: true } - }, - customPrompts: { - code: { customInstructions: 'Code mode specific instructions' } - }, - mode: 'code', - mcpEnabled: false, - browserViewportSize: '900x600' - } as any) - - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - await messageHandler({ type: 'getSystemPrompt', mode: 'code' }) - - // Verify addCustomInstructions was called with mode-specific instructions - expect(addCustomInstructions).toHaveBeenCalledWith( - { - customInstructions: undefined, - customPrompts: { - code: { customInstructions: 'Code mode specific instructions' } - }, - preferredLanguage: undefined - }, - expect.any(String), - 'code' - ) - }) - - test('uses correct mode-specific instructions when mode is specified', async () => { - const systemPrompt = require('../../prompts/system') - const { addCustomInstructions } = systemPrompt - - // Mock getState to return instructions for multiple modes - jest.spyOn(provider, 'getState').mockResolvedValue({ - apiConfiguration: { - apiProvider: 'openrouter', - openRouterModelInfo: { supportsComputerUse: true } - }, - customPrompts: { - code: { customInstructions: 'Code mode instructions' }, - architect: { customInstructions: 'Architect mode instructions' } - }, - mode: 'code', - mcpEnabled: false, - browserViewportSize: '900x600' - } as any) - - const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] - - // Request architect mode prompt - await messageHandler({ type: 'getSystemPrompt', mode: 'architect' }) - - // Verify architect mode instructions were used - expect(addCustomInstructions).toHaveBeenCalledWith( - { - customInstructions: undefined, - customPrompts: { - code: { customInstructions: 'Code mode instructions' }, - architect: { customInstructions: 'Architect mode instructions' } - }, - preferredLanguage: undefined - }, - expect.any(String), - 'architect' - ) - }) - }) +describe("ClineProvider", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + let mockPostMessage: jest.Mock + let visibilityChangeCallback: (e?: unknown) => void + + beforeEach(() => { + // Reset mocks + jest.clearAllMocks() + + // Mock context + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: jest.fn().mockImplementation((key: string) => { + switch (key) { + case "mode": + return "architect" + case "currentApiConfigName": + return "new-config" + default: + return undefined + } + }), + update: jest.fn(), + keys: jest.fn().mockReturnValue([]), + }, + secrets: { + get: jest.fn(), + store: jest.fn(), + delete: jest.fn(), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + // Mock output channel + mockOutputChannel = { + appendLine: jest.fn(), + clear: jest.fn(), + dispose: jest.fn(), + } as unknown as vscode.OutputChannel + + // Mock webview + mockPostMessage = jest.fn() + mockWebviewView = { + webview: { + postMessage: mockPostMessage, + html: "", + options: {}, + onDidReceiveMessage: jest.fn(), + asWebviewUri: jest.fn(), + }, + visible: true, + onDidDispose: jest.fn().mockImplementation((callback) => { + callback() + return { dispose: jest.fn() } + }), + onDidChangeVisibility: jest.fn().mockImplementation((callback) => { + visibilityChangeCallback = callback + return { dispose: jest.fn() } + }), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel) + }) + + test("constructor initializes correctly", () => { + expect(provider).toBeInstanceOf(ClineProvider) + // Since getVisibleInstance returns the last instance where view.visible is true + // @ts-ignore - accessing private property for testing + provider.view = mockWebviewView + expect(ClineProvider.getVisibleInstance()).toBe(provider) + }) + + test("resolveWebviewView sets up webview correctly", () => { + provider.resolveWebviewView(mockWebviewView) + + expect(mockWebviewView.webview.options).toEqual({ + enableScripts: true, + localResourceRoots: [mockContext.extensionUri], + }) + expect(mockWebviewView.webview.html).toContain("") + }) + + test("postMessageToWebview sends message to webview", async () => { + provider.resolveWebviewView(mockWebviewView) + + const mockState: ExtensionState = { + version: "1.0.0", + preferredLanguage: "English", + clineMessages: [], + taskHistory: [], + shouldShowAnnouncement: false, + apiConfiguration: { + apiProvider: "openrouter", + }, + customInstructions: undefined, + alwaysAllowReadOnly: false, + alwaysAllowWrite: false, + alwaysAllowExecute: false, + alwaysAllowBrowser: false, + alwaysAllowMcp: false, + uriScheme: "vscode", + soundEnabled: false, + diffEnabled: false, + writeDelayMs: 1000, + browserViewportSize: "900x600", + fuzzyMatchThreshold: 1.0, + mcpEnabled: true, + requestDelaySeconds: 5, + mode: defaultModeSlug, + } + + const message: ExtensionMessage = { + type: "state", + state: mockState, + } + await provider.postMessageToWebview(message) + + expect(mockPostMessage).toHaveBeenCalledWith(message) + }) + + test("handles webviewDidLaunch message", async () => { + provider.resolveWebviewView(mockWebviewView) + + // Get the message handler from onDidReceiveMessage + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Simulate webviewDidLaunch message + await messageHandler({ type: "webviewDidLaunch" }) + + // Should post state and theme to webview + expect(mockPostMessage).toHaveBeenCalled() + }) + + test("clearTask aborts current task", async () => { + const mockAbortTask = jest.fn() + // @ts-ignore - accessing private property for testing + provider.cline = { abortTask: mockAbortTask } + + await provider.clearTask() + + expect(mockAbortTask).toHaveBeenCalled() + // @ts-ignore - accessing private property for testing + expect(provider.cline).toBeUndefined() + }) + + test("getState returns correct initial state", async () => { + const state = await provider.getState() + + expect(state).toHaveProperty("apiConfiguration") + expect(state.apiConfiguration).toHaveProperty("apiProvider") + expect(state).toHaveProperty("customInstructions") + expect(state).toHaveProperty("alwaysAllowReadOnly") + expect(state).toHaveProperty("alwaysAllowWrite") + expect(state).toHaveProperty("alwaysAllowExecute") + expect(state).toHaveProperty("alwaysAllowBrowser") + expect(state).toHaveProperty("taskHistory") + expect(state).toHaveProperty("soundEnabled") + expect(state).toHaveProperty("diffEnabled") + expect(state).toHaveProperty("writeDelayMs") + }) + + test("preferredLanguage defaults to VSCode language when not set", async () => { + // Mock VSCode language as Spanish + ;(vscode.env as any).language = "es-ES" + + const state = await provider.getState() + expect(state.preferredLanguage).toBe("Spanish") + }) + + test("preferredLanguage defaults to English for unsupported VSCode language", async () => { + // Mock VSCode language as an unsupported language + ;(vscode.env as any).language = "unsupported-LANG" + + const state = await provider.getState() + expect(state.preferredLanguage).toBe("English") + }) + + test("diffEnabled defaults to true when not set", async () => { + // Mock globalState.get to return undefined for diffEnabled + ;(mockContext.globalState.get as jest.Mock).mockReturnValue(undefined) + + const state = await provider.getState() + + expect(state.diffEnabled).toBe(true) + }) + + test("writeDelayMs defaults to 1000ms", async () => { + // Mock globalState.get to return undefined for writeDelayMs + ;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { + if (key === "writeDelayMs") { + return undefined + } + return null + }) + + const state = await provider.getState() + expect(state.writeDelayMs).toBe(1000) + }) + + test("handles writeDelayMs message", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + await messageHandler({ type: "writeDelayMs", value: 2000 }) + + expect(mockContext.globalState.update).toHaveBeenCalledWith("writeDelayMs", 2000) + expect(mockPostMessage).toHaveBeenCalled() + }) + + test("updates sound utility when sound setting changes", async () => { + provider.resolveWebviewView(mockWebviewView) + + // Get the message handler from onDidReceiveMessage + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Simulate setting sound to enabled + await messageHandler({ type: "soundEnabled", bool: true }) + expect(setSoundEnabled).toHaveBeenCalledWith(true) + expect(mockContext.globalState.update).toHaveBeenCalledWith("soundEnabled", true) + expect(mockPostMessage).toHaveBeenCalled() + + // Simulate setting sound to disabled + await messageHandler({ type: "soundEnabled", bool: false }) + expect(setSoundEnabled).toHaveBeenCalledWith(false) + expect(mockContext.globalState.update).toHaveBeenCalledWith("soundEnabled", false) + expect(mockPostMessage).toHaveBeenCalled() + }) + + test("requestDelaySeconds defaults to 5 seconds", async () => { + // Mock globalState.get to return undefined for requestDelaySeconds + ;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { + if (key === "requestDelaySeconds") { + return undefined + } + return null + }) + + const state = await provider.getState() + expect(state.requestDelaySeconds).toBe(5) + }) + + test("alwaysApproveResubmit defaults to false", async () => { + // Mock globalState.get to return undefined for alwaysApproveResubmit + ;(mockContext.globalState.get as jest.Mock).mockReturnValue(undefined) + + const state = await provider.getState() + expect(state.alwaysApproveResubmit).toBe(false) + }) + + test("loads saved API config when switching modes", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Mock ConfigManager methods + provider.configManager = { + GetModeConfigId: jest.fn().mockResolvedValue("test-id"), + ListConfig: jest.fn().mockResolvedValue([{ name: "test-config", id: "test-id", apiProvider: "anthropic" }]), + LoadConfig: jest.fn().mockResolvedValue({ apiProvider: "anthropic" }), + SetModeConfig: jest.fn(), + } as any + + // Switch to architect mode + await messageHandler({ type: "mode", text: "architect" }) + + // Should load the saved config for architect mode + expect(provider.configManager.GetModeConfigId).toHaveBeenCalledWith("architect") + expect(provider.configManager.LoadConfig).toHaveBeenCalledWith("test-config") + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentApiConfigName", "test-config") + }) + + test("saves current config when switching to mode without config", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Mock ConfigManager methods + provider.configManager = { + GetModeConfigId: jest.fn().mockResolvedValue(undefined), + ListConfig: jest + .fn() + .mockResolvedValue([{ name: "current-config", id: "current-id", apiProvider: "anthropic" }]), + SetModeConfig: jest.fn(), + } as any + + // Mock current config name + ;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { + if (key === "currentApiConfigName") { + return "current-config" + } + return undefined + }) + + // Switch to architect mode + await messageHandler({ type: "mode", text: "architect" }) + + // Should save current config as default for architect mode + expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith("architect", "current-id") + }) + + test("saves config as default for current mode when loading config", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + provider.configManager = { + LoadConfig: jest.fn().mockResolvedValue({ apiProvider: "anthropic", id: "new-id" }), + ListConfig: jest.fn().mockResolvedValue([{ name: "new-config", id: "new-id", apiProvider: "anthropic" }]), + SetModeConfig: jest.fn(), + GetModeConfigId: jest.fn().mockResolvedValue(undefined), + } as any + + // First set the mode + await messageHandler({ type: "mode", text: "architect" }) + + // Then load the config + await messageHandler({ type: "loadApiConfiguration", text: "new-config" }) + + // Should save new config as default for architect mode + expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith("architect", "new-id") + }) + + test("handles request delay settings messages", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Test alwaysApproveResubmit + await messageHandler({ type: "alwaysApproveResubmit", bool: true }) + expect(mockContext.globalState.update).toHaveBeenCalledWith("alwaysApproveResubmit", true) + expect(mockPostMessage).toHaveBeenCalled() + + // Test requestDelaySeconds + await messageHandler({ type: "requestDelaySeconds", value: 10 }) + expect(mockContext.globalState.update).toHaveBeenCalledWith("requestDelaySeconds", 10) + expect(mockPostMessage).toHaveBeenCalled() + }) + + test("handles updatePrompt message correctly", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Mock existing prompts + const existingPrompts = { + code: "existing code prompt", + architect: "existing architect prompt", + } + ;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { + if (key === "customPrompts") { + return existingPrompts + } + return undefined + }) + + // Test updating a prompt + await messageHandler({ + type: "updatePrompt", + promptMode: "code", + customPrompt: "new code prompt", + }) + + // Verify state was updated correctly + expect(mockContext.globalState.update).toHaveBeenCalledWith("customPrompts", { + ...existingPrompts, + code: "new code prompt", + }) + + // Verify state was posted to webview + expect(mockPostMessage).toHaveBeenCalledWith( + expect.objectContaining({ + type: "state", + state: expect.objectContaining({ + customPrompts: { + ...existingPrompts, + code: "new code prompt", + }, + }), + }), + ) + }) + + test("customPrompts defaults to empty object", async () => { + // Mock globalState.get to return undefined for customPrompts + ;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => { + if (key === "customPrompts") { + return undefined + } + return null + }) + + const state = await provider.getState() + expect(state.customPrompts).toEqual({}) + }) + + test("uses mode-specific custom instructions in Cline initialization", async () => { + // Setup mock state + const modeCustomInstructions = "Code mode instructions" + const mockApiConfig = { + apiProvider: "openrouter", + openRouterModelInfo: { supportsComputerUse: true }, + } + + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: mockApiConfig, + customPrompts: { + code: { customInstructions: modeCustomInstructions }, + }, + mode: "code", + diffEnabled: true, + fuzzyMatchThreshold: 1.0, + } as any) + + // Reset Cline mock + const { Cline } = require("../../Cline") + ;(Cline as jest.Mock).mockClear() + + // Initialize Cline with a task + await provider.initClineWithTask("Test task") + + // Verify Cline was initialized with mode-specific instructions + expect(Cline).toHaveBeenCalledWith( + provider, + mockApiConfig, + modeCustomInstructions, + true, + 1.0, + "Test task", + undefined, + undefined, + undefined, + ) + }) + test("handles mode-specific custom instructions updates", async () => { + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Mock existing prompts + const existingPrompts = { + code: { + roleDefinition: "Code role", + customInstructions: "Old instructions", + }, + } + mockContext.globalState.get = jest.fn((key: string) => { + if (key === "customPrompts") { + return existingPrompts + } + return undefined + }) + + // Update custom instructions for code mode + await messageHandler({ + type: "updatePrompt", + promptMode: "code", + customPrompt: { + roleDefinition: "Code role", + customInstructions: "New instructions", + }, + }) + + // Verify state was updated correctly + expect(mockContext.globalState.update).toHaveBeenCalledWith("customPrompts", { + code: { + roleDefinition: "Code role", + customInstructions: "New instructions", + }, + }) + }) + + test("saves mode config when updating API configuration", async () => { + // Setup mock context with mode and config name + mockContext = { + ...mockContext, + globalState: { + ...mockContext.globalState, + get: jest.fn((key: string) => { + if (key === "mode") { + return "code" + } else if (key === "currentApiConfigName") { + return "test-config" + } + return undefined + }), + update: jest.fn(), + keys: jest.fn().mockReturnValue([]), + }, + } as unknown as vscode.ExtensionContext + + // Create new provider with updated mock context + provider = new ClineProvider(mockContext, mockOutputChannel) + provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + provider.configManager = { + ListConfig: jest.fn().mockResolvedValue([{ name: "test-config", id: "test-id", apiProvider: "anthropic" }]), + SetModeConfig: jest.fn(), + } as any + + // Update API configuration + await messageHandler({ + type: "apiConfiguration", + apiConfiguration: { apiProvider: "anthropic" }, + }) + + // Should save config as default for current mode + expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith("code", "test-id") + }) + + test("file content includes line numbers", async () => { + const { extractTextFromFile } = require("../../../integrations/misc/extract-text") + const result = await extractTextFromFile("test.js") + expect(result).toBe("1 | const x = 1;\n2 | const y = 2;\n3 | const z = 3;") + }) + + describe("deleteMessage", () => { + beforeEach(() => { + // Mock window.showInformationMessage + ;(vscode.window.showInformationMessage as jest.Mock) = jest.fn() + provider.resolveWebviewView(mockWebviewView) + }) + + test('handles "Just this message" deletion correctly', async () => { + // Mock user selecting "Just this message" + ;(vscode.window.showInformationMessage as jest.Mock).mockResolvedValue("Just this message") + + // Setup mock messages + const mockMessages = [ + { ts: 1000, type: "say", say: "user_feedback" }, // User message 1 + { ts: 2000, type: "say", say: "tool" }, // Tool message + { ts: 3000, type: "say", say: "text", value: 4000 }, // Message to delete + { ts: 4000, type: "say", say: "browser_action" }, // Response to delete + { ts: 5000, type: "say", say: "user_feedback" }, // Next user message + { ts: 6000, type: "say", say: "user_feedback" }, // Final message + ] + + const mockApiHistory = [{ ts: 1000 }, { ts: 2000 }, { ts: 3000 }, { ts: 4000 }, { ts: 5000 }, { ts: 6000 }] + + // Setup Cline instance with mock data + const mockCline = { + clineMessages: mockMessages, + apiConversationHistory: mockApiHistory, + overwriteClineMessages: jest.fn(), + overwriteApiConversationHistory: jest.fn(), + taskId: "test-task-id", + abortTask: jest.fn(), + handleWebviewAskResponse: jest.fn(), + } + // @ts-ignore - accessing private property for testing + provider.cline = mockCline + + // Mock getTaskWithId + ;(provider as any).getTaskWithId = jest.fn().mockResolvedValue({ + historyItem: { id: "test-task-id" }, + }) + + // Trigger message deletion + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + await messageHandler({ type: "deleteMessage", value: 4000 }) + + // Verify correct messages were kept + expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([ + mockMessages[0], + mockMessages[1], + mockMessages[4], + mockMessages[5], + ]) + + // Verify correct API messages were kept + expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([ + mockApiHistory[0], + mockApiHistory[1], + mockApiHistory[4], + mockApiHistory[5], + ]) + }) + + test('handles "This and all subsequent messages" deletion correctly', async () => { + // Mock user selecting "This and all subsequent messages" + ;(vscode.window.showInformationMessage as jest.Mock).mockResolvedValue("This and all subsequent messages") + + // Setup mock messages + const mockMessages = [ + { ts: 1000, type: "say", say: "user_feedback" }, + { ts: 2000, type: "say", say: "text", value: 3000 }, // Message to delete + { ts: 3000, type: "say", say: "user_feedback" }, + { ts: 4000, type: "say", say: "user_feedback" }, + ] + + const mockApiHistory = [{ ts: 1000 }, { ts: 2000 }, { ts: 3000 }, { ts: 4000 }] + + // Setup Cline instance with mock data + const mockCline = { + clineMessages: mockMessages, + apiConversationHistory: mockApiHistory, + overwriteClineMessages: jest.fn(), + overwriteApiConversationHistory: jest.fn(), + taskId: "test-task-id", + abortTask: jest.fn(), + handleWebviewAskResponse: jest.fn(), + } + // @ts-ignore - accessing private property for testing + provider.cline = mockCline + + // Mock getTaskWithId + ;(provider as any).getTaskWithId = jest.fn().mockResolvedValue({ + historyItem: { id: "test-task-id" }, + }) + + // Trigger message deletion + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + await messageHandler({ type: "deleteMessage", value: 3000 }) + + // Verify only messages before the deleted message were kept + expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0]]) + + // Verify only API messages before the deleted message were kept + expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([mockApiHistory[0]]) + }) + + test("handles Cancel correctly", async () => { + // Mock user selecting "Cancel" + ;(vscode.window.showInformationMessage as jest.Mock).mockResolvedValue("Cancel") + + const mockCline = { + clineMessages: [{ ts: 1000 }, { ts: 2000 }], + apiConversationHistory: [{ ts: 1000 }, { ts: 2000 }], + overwriteClineMessages: jest.fn(), + overwriteApiConversationHistory: jest.fn(), + taskId: "test-task-id", + } + // @ts-ignore - accessing private property for testing + provider.cline = mockCline + + // Trigger message deletion + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + await messageHandler({ type: "deleteMessage", value: 2000 }) + + // Verify no messages were deleted + expect(mockCline.overwriteClineMessages).not.toHaveBeenCalled() + expect(mockCline.overwriteApiConversationHistory).not.toHaveBeenCalled() + }) + }) + + describe("getSystemPrompt", () => { + beforeEach(() => { + mockPostMessage.mockClear() + provider.resolveWebviewView(mockWebviewView) + }) + + const getMessageHandler = () => { + const mockCalls = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls + expect(mockCalls.length).toBeGreaterThan(0) + return mockCalls[0][0] + } + + test("handles mcpEnabled setting correctly", async () => { + // Mock getState to return mcpEnabled: true + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + apiProvider: "openrouter" as const, + openRouterModelInfo: { + supportsComputerUse: true, + supportsPromptCache: false, + maxTokens: 4096, + contextWindow: 8192, + supportsImages: false, + inputPrice: 0.0, + outputPrice: 0.0, + description: undefined, + }, + }, + mcpEnabled: true, + mode: "code" as const, + } as any) + + const handler1 = getMessageHandler() + expect(typeof handler1).toBe("function") + await handler1({ type: "getSystemPrompt", mode: "code" }) + + // Verify mcpHub is passed when mcpEnabled is true + expect(mockPostMessage).toHaveBeenCalledWith( + expect.objectContaining({ + type: "systemPrompt", + text: expect.any(String), + }), + ) + + // Mock getState to return mcpEnabled: false + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + apiProvider: "openrouter" as const, + openRouterModelInfo: { + supportsComputerUse: true, + supportsPromptCache: false, + maxTokens: 4096, + contextWindow: 8192, + supportsImages: false, + inputPrice: 0.0, + outputPrice: 0.0, + description: undefined, + }, + }, + mcpEnabled: false, + mode: "code" as const, + } as any) + + const handler2 = getMessageHandler() + await handler2({ type: "getSystemPrompt", mode: "code" }) + + // Verify mcpHub is not passed when mcpEnabled is false + expect(mockPostMessage).toHaveBeenCalledWith( + expect.objectContaining({ + type: "systemPrompt", + text: expect.any(String), + }), + ) + }) + + test("handles errors gracefully", async () => { + // Mock SYSTEM_PROMPT to throw an error + const systemPrompt = require("../../prompts/system") + jest.spyOn(systemPrompt, "SYSTEM_PROMPT").mockRejectedValueOnce(new Error("Test error")) + + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + await messageHandler({ type: "getSystemPrompt", mode: "code" }) + + expect(vscode.window.showErrorMessage).toHaveBeenCalledWith("Failed to get system prompt") + }) + + test("uses mode-specific custom instructions in system prompt", async () => { + const systemPrompt = require("../../prompts/system") + const { addCustomInstructions } = systemPrompt + + // Mock getState to return mode-specific custom instructions + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelInfo: { supportsComputerUse: true }, + }, + customPrompts: { + code: { customInstructions: "Code mode specific instructions" }, + }, + mode: "code", + mcpEnabled: false, + browserViewportSize: "900x600", + } as any) + + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + await messageHandler({ type: "getSystemPrompt", mode: "code" }) + + // Verify addCustomInstructions was called with mode-specific instructions + expect(addCustomInstructions).toHaveBeenCalledWith( + { + customInstructions: undefined, + customPrompts: { + code: { customInstructions: "Code mode specific instructions" }, + }, + preferredLanguage: undefined, + }, + expect.any(String), + "code", + ) + }) + + test("uses correct mode-specific instructions when mode is specified", async () => { + const systemPrompt = require("../../prompts/system") + const { addCustomInstructions } = systemPrompt + + // Mock getState to return instructions for multiple modes + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelInfo: { supportsComputerUse: true }, + }, + customPrompts: { + code: { customInstructions: "Code mode instructions" }, + architect: { customInstructions: "Architect mode instructions" }, + }, + mode: "code", + mcpEnabled: false, + browserViewportSize: "900x600", + } as any) + + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Request architect mode prompt + await messageHandler({ type: "getSystemPrompt", mode: "architect" }) + + // Verify architect mode instructions were used + expect(addCustomInstructions).toHaveBeenCalledWith( + { + customInstructions: undefined, + customPrompts: { + code: { customInstructions: "Code mode instructions" }, + architect: { customInstructions: "Architect mode instructions" }, + }, + preferredLanguage: undefined, + }, + expect.any(String), + "architect", + ) + }) + }) }) diff --git a/src/extension.ts b/src/extension.ts index 31ba8a7..165d4f1 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -27,13 +27,11 @@ export function activate(context: vscode.ExtensionContext) { outputChannel.appendLine("Cline extension activated") // Get default commands from configuration - const defaultCommands = vscode.workspace - .getConfiguration('roo-cline') - .get('allowedCommands') || []; + const defaultCommands = vscode.workspace.getConfiguration("roo-cline").get("allowedCommands") || [] // Initialize global state if not already set - if (!context.globalState.get('allowedCommands')) { - context.globalState.update('allowedCommands', defaultCommands); + if (!context.globalState.get("allowedCommands")) { + context.globalState.update("allowedCommands", defaultCommands) } const sidebarProvider = new ClineProvider(context, outputChannel) diff --git a/src/integrations/editor/DiffViewProvider.ts b/src/integrations/editor/DiffViewProvider.ts index 797bb90..a145a22 100644 --- a/src/integrations/editor/DiffViewProvider.ts +++ b/src/integrations/editor/DiffViewProvider.ts @@ -132,10 +132,10 @@ export class DiffViewProvider { // Apply the final content const finalEdit = new vscode.WorkspaceEdit() finalEdit.replace(document.uri, new vscode.Range(0, 0, document.lineCount, 0), accumulatedContent) - await vscode.workspace.applyEdit(finalEdit) - // Clear all decorations at the end (after applying final edit) - this.fadedOverlayController.clear() - this.activeLineController.clear() + await vscode.workspace.applyEdit(finalEdit) + // Clear all decorations at the end (after applying final edit) + this.fadedOverlayController.clear() + this.activeLineController.clear() } } @@ -352,4 +352,4 @@ export class DiffViewProvider { this.streamedLines = [] this.preDiagnostics = [] } -} \ No newline at end of file +} diff --git a/src/integrations/editor/__tests__/DiffViewProvider.test.ts b/src/integrations/editor/__tests__/DiffViewProvider.test.ts index debc408..8de10a6 100644 --- a/src/integrations/editor/__tests__/DiffViewProvider.test.ts +++ b/src/integrations/editor/__tests__/DiffViewProvider.test.ts @@ -1,8 +1,8 @@ -import { DiffViewProvider } from '../DiffViewProvider'; -import * as vscode from 'vscode'; +import { DiffViewProvider } from "../DiffViewProvider" +import * as vscode from "vscode" // Mock vscode -jest.mock('vscode', () => ({ +jest.mock("vscode", () => ({ workspace: { applyEdit: jest.fn(), }, @@ -19,34 +19,34 @@ jest.mock('vscode', () => ({ TextEditorRevealType: { InCenter: 2, }, -})); +})) // Mock DecorationController -jest.mock('../DecorationController', () => ({ +jest.mock("../DecorationController", () => ({ DecorationController: jest.fn().mockImplementation(() => ({ setActiveLine: jest.fn(), updateOverlayAfterLine: jest.fn(), clear: jest.fn(), })), -})); +})) -describe('DiffViewProvider', () => { - let diffViewProvider: DiffViewProvider; - const mockCwd = '/mock/cwd'; - let mockWorkspaceEdit: { replace: jest.Mock; delete: jest.Mock }; +describe("DiffViewProvider", () => { + let diffViewProvider: DiffViewProvider + const mockCwd = "/mock/cwd" + let mockWorkspaceEdit: { replace: jest.Mock; delete: jest.Mock } beforeEach(() => { - jest.clearAllMocks(); + jest.clearAllMocks() mockWorkspaceEdit = { replace: jest.fn(), delete: jest.fn(), - }; - (vscode.WorkspaceEdit as jest.Mock).mockImplementation(() => mockWorkspaceEdit); + } + ;(vscode.WorkspaceEdit as jest.Mock).mockImplementation(() => mockWorkspaceEdit) - diffViewProvider = new DiffViewProvider(mockCwd); + diffViewProvider = new DiffViewProvider(mockCwd) // Mock the necessary properties and methods - (diffViewProvider as any).relPath = 'test.txt'; - (diffViewProvider as any).activeDiffEditor = { + ;(diffViewProvider as any).relPath = "test.txt" + ;(diffViewProvider as any).activeDiffEditor = { document: { uri: { fsPath: `${mockCwd}/test.txt` }, getText: jest.fn(), @@ -58,43 +58,39 @@ describe('DiffViewProvider', () => { }, edit: jest.fn().mockResolvedValue(true), revealRange: jest.fn(), - }; - (diffViewProvider as any).activeLineController = { setActiveLine: jest.fn(), clear: jest.fn() }; - (diffViewProvider as any).fadedOverlayController = { updateOverlayAfterLine: jest.fn(), clear: jest.fn() }; - }); + } + ;(diffViewProvider as any).activeLineController = { setActiveLine: jest.fn(), clear: jest.fn() } + ;(diffViewProvider as any).fadedOverlayController = { updateOverlayAfterLine: jest.fn(), clear: jest.fn() } + }) - describe('update method', () => { - it('should preserve empty last line when original content has one', async () => { - (diffViewProvider as any).originalContent = 'Original content\n'; - await diffViewProvider.update('New content', true); + describe("update method", () => { + it("should preserve empty last line when original content has one", async () => { + ;(diffViewProvider as any).originalContent = "Original content\n" + await diffViewProvider.update("New content", true) expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith( expect.anything(), expect.anything(), - 'New content\n' - ); - }); + "New content\n", + ) + }) - it('should not add extra newline when accumulated content already ends with one', async () => { - (diffViewProvider as any).originalContent = 'Original content\n'; - await diffViewProvider.update('New content\n', true); + it("should not add extra newline when accumulated content already ends with one", async () => { + ;(diffViewProvider as any).originalContent = "Original content\n" + await diffViewProvider.update("New content\n", true) expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith( expect.anything(), expect.anything(), - 'New content\n' - ); - }); + "New content\n", + ) + }) - it('should not add newline when original content does not end with one', async () => { - (diffViewProvider as any).originalContent = 'Original content'; - await diffViewProvider.update('New content', true); + it("should not add newline when original content does not end with one", async () => { + ;(diffViewProvider as any).originalContent = "Original content" + await diffViewProvider.update("New content", true) - expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith( - expect.anything(), - expect.anything(), - 'New content' - ); - }); - }); -}); \ No newline at end of file + expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(expect.anything(), expect.anything(), "New content") + }) + }) +}) diff --git a/src/integrations/editor/__tests__/detect-omission.test.ts b/src/integrations/editor/__tests__/detect-omission.test.ts index 558617e..3f0ffce 100644 --- a/src/integrations/editor/__tests__/detect-omission.test.ts +++ b/src/integrations/editor/__tests__/detect-omission.test.ts @@ -1,6 +1,6 @@ -import { detectCodeOmission } from '../detect-omission' +import { detectCodeOmission } from "../detect-omission" -describe('detectCodeOmission', () => { +describe("detectCodeOmission", () => { const originalContent = `function example() { // Some code const x = 1; @@ -10,124 +10,132 @@ describe('detectCodeOmission', () => { const generateLongContent = (commentLine: string, length: number = 90) => { return `${commentLine} - ${Array.from({ length }, (_, i) => `const x${i} = ${i};`).join('\n')} + ${Array.from({ length }, (_, i) => `const x${i} = ${i};`).join("\n")} const y = 2;` } - it('should skip comment checks for files under 100 lines', () => { + it("should skip comment checks for files under 100 lines", () => { const newContent = `// Lines 1-50 remain unchanged const z = 3;` const predictedLineCount = 50 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should not detect regular comments without omission keywords', () => { - const newContent = generateLongContent('// Adding new functionality') + it("should not detect regular comments without omission keywords", () => { + const newContent = generateLongContent("// Adding new functionality") const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should not detect when comment is part of original content', () => { + it("should not detect when comment is part of original content", () => { const originalWithComment = `// Content remains unchanged ${originalContent}` - const newContent = generateLongContent('// Content remains unchanged') + const newContent = generateLongContent("// Content remains unchanged") const predictedLineCount = 150 expect(detectCodeOmission(originalWithComment, newContent, predictedLineCount)).toBe(false) }) - it('should not detect code that happens to contain omission keywords', () => { + it("should not detect code that happens to contain omission keywords", () => { const newContent = generateLongContent(`const remains = 'some value'; const unchanged = true;`) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should detect suspicious single-line comment when content is more than 20% shorter', () => { - const newContent = generateLongContent('// Previous content remains here\nconst x = 1;') + it("should detect suspicious single-line comment when content is more than 20% shorter", () => { + const newContent = generateLongContent("// Previous content remains here\nconst x = 1;") const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true) }) - it('should not flag suspicious single-line comment when content is less than 20% shorter', () => { - const newContent = generateLongContent('// Previous content remains here', 130) + it("should not flag suspicious single-line comment when content is less than 20% shorter", () => { + const newContent = generateLongContent("// Previous content remains here", 130) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should detect suspicious Python-style comment when content is more than 20% shorter', () => { - const newContent = generateLongContent('# Previous content remains here\nconst x = 1;') + it("should detect suspicious Python-style comment when content is more than 20% shorter", () => { + const newContent = generateLongContent("# Previous content remains here\nconst x = 1;") const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true) }) - it('should not flag suspicious Python-style comment when content is less than 20% shorter', () => { - const newContent = generateLongContent('# Previous content remains here', 130) + it("should not flag suspicious Python-style comment when content is less than 20% shorter", () => { + const newContent = generateLongContent("# Previous content remains here", 130) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should detect suspicious multi-line comment when content is more than 20% shorter', () => { - const newContent = generateLongContent('/* Previous content remains the same */\nconst x = 1;') + it("should detect suspicious multi-line comment when content is more than 20% shorter", () => { + const newContent = generateLongContent("/* Previous content remains the same */\nconst x = 1;") const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true) }) - it('should not flag suspicious multi-line comment when content is less than 20% shorter', () => { - const newContent = generateLongContent('/* Previous content remains the same */', 130) + it("should not flag suspicious multi-line comment when content is less than 20% shorter", () => { + const newContent = generateLongContent("/* Previous content remains the same */", 130) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should detect suspicious JSX comment when content is more than 20% shorter', () => { - const newContent = generateLongContent('{/* Rest of the code remains the same */}\nconst x = 1;') + it("should detect suspicious JSX comment when content is more than 20% shorter", () => { + const newContent = generateLongContent("{/* Rest of the code remains the same */}\nconst x = 1;") const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true) }) - it('should not flag suspicious JSX comment when content is less than 20% shorter', () => { - const newContent = generateLongContent('{/* Rest of the code remains the same */}', 130) + it("should not flag suspicious JSX comment when content is less than 20% shorter", () => { + const newContent = generateLongContent("{/* Rest of the code remains the same */}", 130) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should detect suspicious HTML comment when content is more than 20% shorter', () => { - const newContent = generateLongContent('\nconst x = 1;') + it("should detect suspicious HTML comment when content is more than 20% shorter", () => { + const newContent = generateLongContent("\nconst x = 1;") const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true) }) - it('should not flag suspicious HTML comment when content is less than 20% shorter', () => { - const newContent = generateLongContent('', 130) + it("should not flag suspicious HTML comment when content is less than 20% shorter", () => { + const newContent = generateLongContent("", 130) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should detect suspicious square bracket notation when content is more than 20% shorter', () => { - const newContent = generateLongContent('[Previous content from line 1-305 remains exactly the same]\nconst x = 1;') + it("should detect suspicious square bracket notation when content is more than 20% shorter", () => { + const newContent = generateLongContent( + "[Previous content from line 1-305 remains exactly the same]\nconst x = 1;", + ) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true) }) - it('should not flag suspicious square bracket notation when content is less than 20% shorter', () => { - const newContent = generateLongContent('[Previous content from line 1-305 remains exactly the same]', 130) + it("should not flag suspicious square bracket notation when content is less than 20% shorter", () => { + const newContent = generateLongContent("[Previous content from line 1-305 remains exactly the same]", 130) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should not flag content very close to predicted length', () => { - const newContent = generateLongContent(`const x = 1; + it("should not flag content very close to predicted length", () => { + const newContent = generateLongContent( + `const x = 1; const y = 2; -// This is a legitimate comment that remains here`, 130) +// This is a legitimate comment that remains here`, + 130, + ) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) - it('should not flag when content is longer than predicted', () => { - const newContent = generateLongContent(`const x = 1; + it("should not flag when content is longer than predicted", () => { + const newContent = generateLongContent( + `const x = 1; const y = 2; // Previous content remains here but we added more const z = 3; -const w = 4;`, 160) +const w = 4;`, + 160, + ) const predictedLineCount = 150 expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false) }) diff --git a/src/integrations/editor/detect-omission.ts b/src/integrations/editor/detect-omission.ts index d8c3e14..50bef62 100644 --- a/src/integrations/editor/detect-omission.ts +++ b/src/integrations/editor/detect-omission.ts @@ -8,7 +8,7 @@ export function detectCodeOmission( originalFileContent: string, newFileContent: string, - predictedLineCount: number + predictedLineCount: number, ): boolean { // Skip all checks if predictedLineCount is less than 100 if (!predictedLineCount || predictedLineCount < 100) { @@ -20,7 +20,17 @@ export function detectCodeOmission( const originalLines = originalFileContent.split("\n") const newLines = newFileContent.split("\n") - const omissionKeywords = ["remain", "remains", "unchanged", "rest", "previous", "existing", "content", "same", "..."] + const omissionKeywords = [ + "remain", + "remains", + "unchanged", + "rest", + "previous", + "existing", + "content", + "same", + "...", + ] const commentPatterns = [ /^\s*\/\//, // Single-line comment for most languages @@ -39,7 +49,7 @@ export function detectCodeOmission( if (omissionKeywords.some((keyword) => words.includes(keyword))) { if (!originalLines.includes(line)) { // For files with 100+ lines, only flag if content is more than 20% shorter - if (lengthRatio <= 0.80) { + if (lengthRatio <= 0.8) { return true } } @@ -48,4 +58,4 @@ export function detectCodeOmission( } return false -} \ No newline at end of file +} diff --git a/src/integrations/misc/__tests__/extract-text.test.ts b/src/integrations/misc/__tests__/extract-text.test.ts index 5b91324..7e084d0 100644 --- a/src/integrations/misc/__tests__/extract-text.test.ts +++ b/src/integrations/misc/__tests__/extract-text.test.ts @@ -1,122 +1,122 @@ -import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers, truncateOutput } from '../extract-text'; +import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers, truncateOutput } from "../extract-text" -describe('addLineNumbers', () => { - it('should add line numbers starting from 1 by default', () => { - const input = 'line 1\nline 2\nline 3'; - const expected = '1 | line 1\n2 | line 2\n3 | line 3'; - expect(addLineNumbers(input)).toBe(expected); - }); +describe("addLineNumbers", () => { + it("should add line numbers starting from 1 by default", () => { + const input = "line 1\nline 2\nline 3" + const expected = "1 | line 1\n2 | line 2\n3 | line 3" + expect(addLineNumbers(input)).toBe(expected) + }) - it('should add line numbers starting from specified line number', () => { - const input = 'line 1\nline 2\nline 3'; - const expected = '10 | line 1\n11 | line 2\n12 | line 3'; - expect(addLineNumbers(input, 10)).toBe(expected); - }); + it("should add line numbers starting from specified line number", () => { + const input = "line 1\nline 2\nline 3" + const expected = "10 | line 1\n11 | line 2\n12 | line 3" + expect(addLineNumbers(input, 10)).toBe(expected) + }) - it('should handle empty content', () => { - expect(addLineNumbers('')).toBe('1 | '); - expect(addLineNumbers('', 5)).toBe('5 | '); - }); + it("should handle empty content", () => { + expect(addLineNumbers("")).toBe("1 | ") + expect(addLineNumbers("", 5)).toBe("5 | ") + }) - it('should handle single line content', () => { - expect(addLineNumbers('single line')).toBe('1 | single line'); - expect(addLineNumbers('single line', 42)).toBe('42 | single line'); - }); + it("should handle single line content", () => { + expect(addLineNumbers("single line")).toBe("1 | single line") + expect(addLineNumbers("single line", 42)).toBe("42 | single line") + }) - it('should pad line numbers based on the highest line number', () => { - const input = 'line 1\nline 2'; + it("should pad line numbers based on the highest line number", () => { + const input = "line 1\nline 2" // When starting from 99, highest line will be 100, so needs 3 spaces padding - const expected = ' 99 | line 1\n100 | line 2'; - expect(addLineNumbers(input, 99)).toBe(expected); - }); -}); + const expected = " 99 | line 1\n100 | line 2" + expect(addLineNumbers(input, 99)).toBe(expected) + }) +}) -describe('everyLineHasLineNumbers', () => { - it('should return true for content with line numbers', () => { - const input = '1 | line one\n2 | line two\n3 | line three'; - expect(everyLineHasLineNumbers(input)).toBe(true); - }); +describe("everyLineHasLineNumbers", () => { + it("should return true for content with line numbers", () => { + const input = "1 | line one\n2 | line two\n3 | line three" + expect(everyLineHasLineNumbers(input)).toBe(true) + }) - it('should return true for content with padded line numbers', () => { - const input = ' 1 | line one\n 2 | line two\n 3 | line three'; - expect(everyLineHasLineNumbers(input)).toBe(true); - }); + it("should return true for content with padded line numbers", () => { + const input = " 1 | line one\n 2 | line two\n 3 | line three" + expect(everyLineHasLineNumbers(input)).toBe(true) + }) - it('should return false for content without line numbers', () => { - const input = 'line one\nline two\nline three'; - expect(everyLineHasLineNumbers(input)).toBe(false); - }); + it("should return false for content without line numbers", () => { + const input = "line one\nline two\nline three" + expect(everyLineHasLineNumbers(input)).toBe(false) + }) - it('should return false for mixed content', () => { - const input = '1 | line one\nline two\n3 | line three'; - expect(everyLineHasLineNumbers(input)).toBe(false); - }); + it("should return false for mixed content", () => { + const input = "1 | line one\nline two\n3 | line three" + expect(everyLineHasLineNumbers(input)).toBe(false) + }) - it('should handle empty content', () => { - expect(everyLineHasLineNumbers('')).toBe(false); - }); + it("should handle empty content", () => { + expect(everyLineHasLineNumbers("")).toBe(false) + }) - it('should return false for content with pipe but no line numbers', () => { - const input = 'a | b\nc | d'; - expect(everyLineHasLineNumbers(input)).toBe(false); - }); -}); + it("should return false for content with pipe but no line numbers", () => { + const input = "a | b\nc | d" + expect(everyLineHasLineNumbers(input)).toBe(false) + }) +}) -describe('stripLineNumbers', () => { - it('should strip line numbers from content', () => { - const input = '1 | line one\n2 | line two\n3 | line three'; - const expected = 'line one\nline two\nline three'; - expect(stripLineNumbers(input)).toBe(expected); - }); +describe("stripLineNumbers", () => { + it("should strip line numbers from content", () => { + const input = "1 | line one\n2 | line two\n3 | line three" + const expected = "line one\nline two\nline three" + expect(stripLineNumbers(input)).toBe(expected) + }) - it('should strip padded line numbers', () => { - const input = ' 1 | line one\n 2 | line two\n 3 | line three'; - const expected = 'line one\nline two\nline three'; - expect(stripLineNumbers(input)).toBe(expected); - }); + it("should strip padded line numbers", () => { + const input = " 1 | line one\n 2 | line two\n 3 | line three" + const expected = "line one\nline two\nline three" + expect(stripLineNumbers(input)).toBe(expected) + }) - it('should handle content without line numbers', () => { - const input = 'line one\nline two\nline three'; - expect(stripLineNumbers(input)).toBe(input); - }); + it("should handle content without line numbers", () => { + const input = "line one\nline two\nline three" + expect(stripLineNumbers(input)).toBe(input) + }) - it('should handle empty content', () => { - expect(stripLineNumbers('')).toBe(''); - }); + it("should handle empty content", () => { + expect(stripLineNumbers("")).toBe("") + }) - it('should preserve content with pipe but no line numbers', () => { - const input = 'a | b\nc | d'; - expect(stripLineNumbers(input)).toBe(input); - }); + it("should preserve content with pipe but no line numbers", () => { + const input = "a | b\nc | d" + expect(stripLineNumbers(input)).toBe(input) + }) - it('should handle windows-style line endings', () => { - const input = '1 | line one\r\n2 | line two\r\n3 | line three'; - const expected = 'line one\r\nline two\r\nline three'; - expect(stripLineNumbers(input)).toBe(expected); - }); + it("should handle windows-style line endings", () => { + const input = "1 | line one\r\n2 | line two\r\n3 | line three" + const expected = "line one\r\nline two\r\nline three" + expect(stripLineNumbers(input)).toBe(expected) + }) - it('should handle content with varying line number widths', () => { - const input = ' 1 | line one\n 10 | line two\n100 | line three'; - const expected = 'line one\nline two\nline three'; - expect(stripLineNumbers(input)).toBe(expected); - }); -}); + it("should handle content with varying line number widths", () => { + const input = " 1 | line one\n 10 | line two\n100 | line three" + const expected = "line one\nline two\nline three" + expect(stripLineNumbers(input)).toBe(expected) + }) +}) -describe('truncateOutput', () => { - it('returns original content when no line limit provided', () => { - const content = 'line1\nline2\nline3' +describe("truncateOutput", () => { + it("returns original content when no line limit provided", () => { + const content = "line1\nline2\nline3" expect(truncateOutput(content)).toBe(content) }) - it('returns original content when lines are under limit', () => { - const content = 'line1\nline2\nline3' + it("returns original content when lines are under limit", () => { + const content = "line1\nline2\nline3" expect(truncateOutput(content, 5)).toBe(content) }) - it('truncates content with 20/80 split when over limit', () => { + it("truncates content with 20/80 split when over limit", () => { // Create 25 lines of content const lines = Array.from({ length: 25 }, (_, i) => `line${i + 1}`) - const content = lines.join('\n') + const content = lines.join("\n") // Set limit to 10 lines const result = truncateOutput(content, 10) @@ -126,51 +126,42 @@ describe('truncateOutput', () => { // - Last 8 lines (80% of 10) // - Omission indicator in between const expectedLines = [ - 'line1', - 'line2', - '', - '[...15 lines omitted...]', - '', - 'line18', - 'line19', - 'line20', - 'line21', - 'line22', - 'line23', - 'line24', - 'line25' + "line1", + "line2", + "", + "[...15 lines omitted...]", + "", + "line18", + "line19", + "line20", + "line21", + "line22", + "line23", + "line24", + "line25", ] - expect(result).toBe(expectedLines.join('\n')) + expect(result).toBe(expectedLines.join("\n")) }) - it('handles empty content', () => { - expect(truncateOutput('', 10)).toBe('') + it("handles empty content", () => { + expect(truncateOutput("", 10)).toBe("") }) - it('handles single line content', () => { - expect(truncateOutput('single line', 10)).toBe('single line') + it("handles single line content", () => { + expect(truncateOutput("single line", 10)).toBe("single line") }) - it('handles windows-style line endings', () => { + it("handles windows-style line endings", () => { // Create content with windows line endings const lines = Array.from({ length: 15 }, (_, i) => `line${i + 1}`) - const content = lines.join('\r\n') + const content = lines.join("\r\n") const result = truncateOutput(content, 5) - + // Should keep first line (20% of 5 = 1) and last 4 lines (80% of 5 = 4) // Split result by either \r\n or \n to normalize line endings const resultLines = result.split(/\r?\n/) - const expectedLines = [ - 'line1', - '', - '[...10 lines omitted...]', - '', - 'line12', - 'line13', - 'line14', - 'line15' - ] + const expectedLines = ["line1", "", "[...10 lines omitted...]", "", "line12", "line13", "line14", "line15"] expect(resultLines).toEqual(expectedLines) }) -}) \ No newline at end of file +}) diff --git a/src/integrations/misc/extract-text.ts b/src/integrations/misc/extract-text.ts index ee70652..0354570 100644 --- a/src/integrations/misc/extract-text.ts +++ b/src/integrations/misc/extract-text.ts @@ -55,19 +55,20 @@ async function extractTextFromIPYNB(filePath: string): Promise { } export function addLineNumbers(content: string, startLine: number = 1): string { - const lines = content.split('\n') + const lines = content.split("\n") const maxLineNumberWidth = String(startLine + lines.length - 1).length return lines .map((line, index) => { - const lineNumber = String(startLine + index).padStart(maxLineNumberWidth, ' ') + const lineNumber = String(startLine + index).padStart(maxLineNumberWidth, " ") return `${lineNumber} | ${line}` - }).join('\n') + }) + .join("\n") } // Checks if every line in the content has line numbers prefixed (e.g., "1 | content" or "123 | content") // Line numbers must be followed by a single pipe character (not double pipes) export function everyLineHasLineNumbers(content: string): boolean { const lines = content.split(/\r?\n/) - return lines.length > 0 && lines.every(line => /^\s*\d+\s+\|(?!\|)/.test(line)) + return lines.length > 0 && lines.every((line) => /^\s*\d+\s+\|(?!\|)/.test(line)) } // Strips line numbers from content while preserving the actual content @@ -76,16 +77,16 @@ export function everyLineHasLineNumbers(content: string): boolean { export function stripLineNumbers(content: string): string { // Split into lines to handle each line individually const lines = content.split(/\r?\n/) - + // Process each line - const processedLines = lines.map(line => { + const processedLines = lines.map((line) => { // Match line number pattern and capture everything after the pipe const match = line.match(/^\s*\d+\s+\|(?!\|)\s?(.*)$/) return match ? match[1] : line }) - + // Join back with original line endings - const lineEnding = content.includes('\r\n') ? '\r\n' : '\n' + const lineEnding = content.includes("\r\n") ? "\r\n" : "\n" return processedLines.join(lineEnding) } @@ -109,7 +110,7 @@ export function truncateOutput(content: string, lineLimit?: number): string { return content } - const lines = content.split('\n') + const lines = content.split("\n") if (lines.length <= lineLimit) { return content } @@ -119,6 +120,6 @@ export function truncateOutput(content: string, lineLimit?: number): string { return [ ...lines.slice(0, beforeLimit), `\n[...${lines.length - lineLimit} lines omitted...]\n`, - ...lines.slice(-afterLimit) - ].join('\n') -} \ No newline at end of file + ...lines.slice(-afterLimit), + ].join("\n") +} diff --git a/src/integrations/misc/open-file.ts b/src/integrations/misc/open-file.ts index 08a3ce1..daf36f1 100644 --- a/src/integrations/misc/open-file.ts +++ b/src/integrations/misc/open-file.ts @@ -21,8 +21,8 @@ export async function openImage(dataUri: string) { } interface OpenFileOptions { - create?: boolean; - content?: string; + create?: boolean + content?: string } export async function openFile(filePath: string, options: OpenFileOptions = {}) { @@ -30,13 +30,11 @@ export async function openFile(filePath: string, options: OpenFileOptions = {}) // Get workspace root const workspaceRoot = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath if (!workspaceRoot) { - throw new Error('No workspace root found') + throw new Error("No workspace root found") } // If path starts with ./, resolve it relative to workspace root - const fullPath = filePath.startsWith('./') ? - path.join(workspaceRoot, filePath.slice(2)) : - filePath + const fullPath = filePath.startsWith("./") ? path.join(workspaceRoot, filePath.slice(2)) : filePath const uri = vscode.Uri.file(fullPath) @@ -46,12 +44,12 @@ export async function openFile(filePath: string, options: OpenFileOptions = {}) } catch { // File doesn't exist if (!options.create) { - throw new Error('File does not exist') + throw new Error("File does not exist") } - + // Create with provided content or empty string - const content = options.content || '' - await vscode.workspace.fs.writeFile(uri, Buffer.from(content, 'utf8')) + const content = options.content || "" + await vscode.workspace.fs.writeFile(uri, Buffer.from(content, "utf8")) } // Check if the document is already open in a tab group that's not in the active editor's column diff --git a/src/integrations/terminal/TerminalManager.ts b/src/integrations/terminal/TerminalManager.ts index 655e037..5234791 100644 --- a/src/integrations/terminal/TerminalManager.ts +++ b/src/integrations/terminal/TerminalManager.ts @@ -146,7 +146,9 @@ export class TerminalManager { process.run(terminal, command) } else { // docs recommend waiting 3s for shell integration to activate - pWaitFor(() => (terminalInfo.terminal as ExtendedTerminal).shellIntegration !== undefined, { timeout: 4000 }).finally(() => { + pWaitFor(() => (terminalInfo.terminal as ExtendedTerminal).shellIntegration !== undefined, { + timeout: 4000, + }).finally(() => { const existingProcess = this.processes.get(terminalInfo.id) if (existingProcess && existingProcess.waitForShellIntegration) { existingProcess.waitForShellIntegration = false diff --git a/src/integrations/terminal/TerminalRegistry.ts b/src/integrations/terminal/TerminalRegistry.ts index 5937047..e016147 100644 --- a/src/integrations/terminal/TerminalRegistry.ts +++ b/src/integrations/terminal/TerminalRegistry.ts @@ -19,8 +19,8 @@ export class TerminalRegistry { name: "Roo Cline", iconPath: new vscode.ThemeIcon("rocket"), env: { - PAGER: "cat" - } + PAGER: "cat", + }, }) const newInfo: TerminalInfo = { terminal, diff --git a/src/integrations/terminal/__tests__/TerminalProcess.test.ts b/src/integrations/terminal/__tests__/TerminalProcess.test.ts index 6294780..9ccbaef 100644 --- a/src/integrations/terminal/__tests__/TerminalProcess.test.ts +++ b/src/integrations/terminal/__tests__/TerminalProcess.test.ts @@ -6,224 +6,228 @@ import { EventEmitter } from "events" jest.mock("vscode") describe("TerminalProcess", () => { - let terminalProcess: TerminalProcess - let mockTerminal: jest.Mocked - let mockExecution: any - let mockStream: AsyncIterableIterator + let terminalProcess: TerminalProcess + let mockTerminal: jest.Mocked< + vscode.Terminal & { + shellIntegration: { + executeCommand: jest.Mock + } + } + > + let mockExecution: any + let mockStream: AsyncIterableIterator - beforeEach(() => { - terminalProcess = new TerminalProcess() - - // Create properly typed mock terminal - mockTerminal = { - shellIntegration: { - executeCommand: jest.fn() - }, - name: "Mock Terminal", - processId: Promise.resolve(123), - creationOptions: {}, - exitStatus: undefined, - state: { isInteractedWith: true }, - dispose: jest.fn(), - hide: jest.fn(), - show: jest.fn(), - sendText: jest.fn() - } as unknown as jest.Mocked + beforeEach(() => { + terminalProcess = new TerminalProcess() - // Reset event listeners - terminalProcess.removeAllListeners() - }) + // Create properly typed mock terminal + mockTerminal = { + shellIntegration: { + executeCommand: jest.fn(), + }, + name: "Mock Terminal", + processId: Promise.resolve(123), + creationOptions: {}, + exitStatus: undefined, + state: { isInteractedWith: true }, + dispose: jest.fn(), + hide: jest.fn(), + show: jest.fn(), + sendText: jest.fn(), + } as unknown as jest.Mocked< + vscode.Terminal & { + shellIntegration: { + executeCommand: jest.Mock + } + } + > - describe("run", () => { - it("handles shell integration commands correctly", async () => { - const lines: string[] = [] - terminalProcess.on("line", (line) => { - // Skip empty lines used for loading spinner - if (line !== "") { - lines.push(line) - } - }) + // Reset event listeners + terminalProcess.removeAllListeners() + }) - // Mock stream data with shell integration sequences - mockStream = (async function* () { - // The first chunk contains the command start sequence - yield "Initial output\n" - yield "More output\n" - // The last chunk contains the command end sequence - yield "Final output" - })() + describe("run", () => { + it("handles shell integration commands correctly", async () => { + const lines: string[] = [] + terminalProcess.on("line", (line) => { + // Skip empty lines used for loading spinner + if (line !== "") { + lines.push(line) + } + }) - mockExecution = { - read: jest.fn().mockReturnValue(mockStream) - } + // Mock stream data with shell integration sequences + mockStream = (async function* () { + // The first chunk contains the command start sequence + yield "Initial output\n" + yield "More output\n" + // The last chunk contains the command end sequence + yield "Final output" + })() - mockTerminal.shellIntegration.executeCommand.mockReturnValue(mockExecution) + mockExecution = { + read: jest.fn().mockReturnValue(mockStream), + } - const completedPromise = new Promise((resolve) => { - terminalProcess.once("completed", resolve) - }) + mockTerminal.shellIntegration.executeCommand.mockReturnValue(mockExecution) - await terminalProcess.run(mockTerminal, "test command") - await completedPromise + const completedPromise = new Promise((resolve) => { + terminalProcess.once("completed", resolve) + }) - expect(lines).toEqual(["Initial output", "More output", "Final output"]) - expect(terminalProcess.isHot).toBe(false) - }) + await terminalProcess.run(mockTerminal, "test command") + await completedPromise - it("handles terminals without shell integration", async () => { - const noShellTerminal = { - sendText: jest.fn(), - shellIntegration: undefined - } as unknown as vscode.Terminal + expect(lines).toEqual(["Initial output", "More output", "Final output"]) + expect(terminalProcess.isHot).toBe(false) + }) - const noShellPromise = new Promise((resolve) => { - terminalProcess.once("no_shell_integration", resolve) - }) + it("handles terminals without shell integration", async () => { + const noShellTerminal = { + sendText: jest.fn(), + shellIntegration: undefined, + } as unknown as vscode.Terminal - await terminalProcess.run(noShellTerminal, "test command") - await noShellPromise + const noShellPromise = new Promise((resolve) => { + terminalProcess.once("no_shell_integration", resolve) + }) - expect(noShellTerminal.sendText).toHaveBeenCalledWith("test command", true) - }) + await terminalProcess.run(noShellTerminal, "test command") + await noShellPromise - it("sets hot state for compiling commands", async () => { - const lines: string[] = [] - terminalProcess.on("line", (line) => { - if (line !== "") { - lines.push(line) - } - }) + expect(noShellTerminal.sendText).toHaveBeenCalledWith("test command", true) + }) - // Create a promise that resolves when the first chunk is processed - const firstChunkProcessed = new Promise(resolve => { - terminalProcess.on("line", () => resolve()) - }) + it("sets hot state for compiling commands", async () => { + const lines: string[] = [] + terminalProcess.on("line", (line) => { + if (line !== "") { + lines.push(line) + } + }) - mockStream = (async function* () { - yield "compiling...\n" - // Wait to ensure hot state check happens after first chunk - await new Promise(resolve => setTimeout(resolve, 10)) - yield "still compiling...\n" - yield "done" - })() + // Create a promise that resolves when the first chunk is processed + const firstChunkProcessed = new Promise((resolve) => { + terminalProcess.on("line", () => resolve()) + }) - mockExecution = { - read: jest.fn().mockReturnValue(mockStream) - } + mockStream = (async function* () { + yield "compiling...\n" + // Wait to ensure hot state check happens after first chunk + await new Promise((resolve) => setTimeout(resolve, 10)) + yield "still compiling...\n" + yield "done" + })() - mockTerminal.shellIntegration.executeCommand.mockReturnValue(mockExecution) + mockExecution = { + read: jest.fn().mockReturnValue(mockStream), + } - // Start the command execution - const runPromise = terminalProcess.run(mockTerminal, "npm run build") - - // Wait for the first chunk to be processed - await firstChunkProcessed - - // Hot state should be true while compiling - expect(terminalProcess.isHot).toBe(true) + mockTerminal.shellIntegration.executeCommand.mockReturnValue(mockExecution) - // Complete the execution - const completedPromise = new Promise((resolve) => { - terminalProcess.once("completed", resolve) - }) + // Start the command execution + const runPromise = terminalProcess.run(mockTerminal, "npm run build") - await runPromise - await completedPromise + // Wait for the first chunk to be processed + await firstChunkProcessed - expect(lines).toEqual(["compiling...", "still compiling...", "done"]) - }) - }) + // Hot state should be true while compiling + expect(terminalProcess.isHot).toBe(true) - describe("buffer processing", () => { - it("correctly processes and emits lines", () => { - const lines: string[] = [] - terminalProcess.on("line", (line) => lines.push(line)) + // Complete the execution + const completedPromise = new Promise((resolve) => { + terminalProcess.once("completed", resolve) + }) - // Simulate incoming chunks - terminalProcess["emitIfEol"]("first line\n") - terminalProcess["emitIfEol"]("second") - terminalProcess["emitIfEol"](" line\n") - terminalProcess["emitIfEol"]("third line") + await runPromise + await completedPromise - expect(lines).toEqual(["first line", "second line"]) + expect(lines).toEqual(["compiling...", "still compiling...", "done"]) + }) + }) - // Process remaining buffer - terminalProcess["emitRemainingBufferIfListening"]() - expect(lines).toEqual(["first line", "second line", "third line"]) - }) + describe("buffer processing", () => { + it("correctly processes and emits lines", () => { + const lines: string[] = [] + terminalProcess.on("line", (line) => lines.push(line)) - it("handles Windows-style line endings", () => { - const lines: string[] = [] - terminalProcess.on("line", (line) => lines.push(line)) + // Simulate incoming chunks + terminalProcess["emitIfEol"]("first line\n") + terminalProcess["emitIfEol"]("second") + terminalProcess["emitIfEol"](" line\n") + terminalProcess["emitIfEol"]("third line") - terminalProcess["emitIfEol"]("line1\r\nline2\r\n") + expect(lines).toEqual(["first line", "second line"]) - expect(lines).toEqual(["line1", "line2"]) - }) - }) + // Process remaining buffer + terminalProcess["emitRemainingBufferIfListening"]() + expect(lines).toEqual(["first line", "second line", "third line"]) + }) - describe("removeLastLineArtifacts", () => { - it("removes terminal artifacts from output", () => { - const cases = [ - ["output%", "output"], - ["output$ ", "output"], - ["output#", "output"], - ["output> ", "output"], - ["multi\nline%", "multi\nline"], - ["no artifacts", "no artifacts"] - ] + it("handles Windows-style line endings", () => { + const lines: string[] = [] + terminalProcess.on("line", (line) => lines.push(line)) - for (const [input, expected] of cases) { - expect(terminalProcess["removeLastLineArtifacts"](input)).toBe(expected) - } - }) - }) + terminalProcess["emitIfEol"]("line1\r\nline2\r\n") - describe("continue", () => { - it("stops listening and emits continue event", () => { - const continueSpy = jest.fn() - terminalProcess.on("continue", continueSpy) + expect(lines).toEqual(["line1", "line2"]) + }) + }) - terminalProcess.continue() + describe("removeLastLineArtifacts", () => { + it("removes terminal artifacts from output", () => { + const cases = [ + ["output%", "output"], + ["output$ ", "output"], + ["output#", "output"], + ["output> ", "output"], + ["multi\nline%", "multi\nline"], + ["no artifacts", "no artifacts"], + ] - expect(continueSpy).toHaveBeenCalled() - expect(terminalProcess["isListening"]).toBe(false) - }) - }) + for (const [input, expected] of cases) { + expect(terminalProcess["removeLastLineArtifacts"](input)).toBe(expected) + } + }) + }) - describe("getUnretrievedOutput", () => { - it("returns and clears unretrieved output", () => { - terminalProcess["fullOutput"] = "previous\nnew output" - terminalProcess["lastRetrievedIndex"] = 9 // After "previous\n" + describe("continue", () => { + it("stops listening and emits continue event", () => { + const continueSpy = jest.fn() + terminalProcess.on("continue", continueSpy) - const unretrieved = terminalProcess.getUnretrievedOutput() + terminalProcess.continue() - expect(unretrieved).toBe("new output") - expect(terminalProcess["lastRetrievedIndex"]).toBe(terminalProcess["fullOutput"].length) - }) - }) + expect(continueSpy).toHaveBeenCalled() + expect(terminalProcess["isListening"]).toBe(false) + }) + }) - describe("mergePromise", () => { - it("merges promise methods with terminal process", async () => { - const process = new TerminalProcess() - const promise = Promise.resolve() + describe("getUnretrievedOutput", () => { + it("returns and clears unretrieved output", () => { + terminalProcess["fullOutput"] = "previous\nnew output" + terminalProcess["lastRetrievedIndex"] = 9 // After "previous\n" - const merged = mergePromise(process, promise) + const unretrieved = terminalProcess.getUnretrievedOutput() - expect(merged).toHaveProperty("then") - expect(merged).toHaveProperty("catch") - expect(merged).toHaveProperty("finally") - expect(merged instanceof TerminalProcess).toBe(true) + expect(unretrieved).toBe("new output") + expect(terminalProcess["lastRetrievedIndex"]).toBe(terminalProcess["fullOutput"].length) + }) + }) - await expect(merged).resolves.toBeUndefined() - }) - }) -}) \ No newline at end of file + describe("mergePromise", () => { + it("merges promise methods with terminal process", async () => { + const process = new TerminalProcess() + const promise = Promise.resolve() + + const merged = mergePromise(process, promise) + + expect(merged).toHaveProperty("then") + expect(merged).toHaveProperty("catch") + expect(merged).toHaveProperty("finally") + expect(merged instanceof TerminalProcess).toBe(true) + + await expect(merged).resolves.toBeUndefined() + }) + }) +}) diff --git a/src/integrations/terminal/__tests__/TerminalRegistry.test.ts b/src/integrations/terminal/__tests__/TerminalRegistry.test.ts index 9aa2483..6f535f0 100644 --- a/src/integrations/terminal/__tests__/TerminalRegistry.test.ts +++ b/src/integrations/terminal/__tests__/TerminalRegistry.test.ts @@ -4,34 +4,34 @@ import { TerminalRegistry } from "../TerminalRegistry" // Mock vscode.window.createTerminal const mockCreateTerminal = jest.fn() jest.mock("vscode", () => ({ - window: { - createTerminal: (...args: any[]) => { - mockCreateTerminal(...args) - return { - exitStatus: undefined, - } - }, - }, - ThemeIcon: jest.fn(), + window: { + createTerminal: (...args: any[]) => { + mockCreateTerminal(...args) + return { + exitStatus: undefined, + } + }, + }, + ThemeIcon: jest.fn(), })) describe("TerminalRegistry", () => { - beforeEach(() => { - mockCreateTerminal.mockClear() - }) + beforeEach(() => { + mockCreateTerminal.mockClear() + }) - describe("createTerminal", () => { - it("creates terminal with PAGER set to cat", () => { - TerminalRegistry.createTerminal("/test/path") + describe("createTerminal", () => { + it("creates terminal with PAGER set to cat", () => { + TerminalRegistry.createTerminal("/test/path") - expect(mockCreateTerminal).toHaveBeenCalledWith({ - cwd: "/test/path", - name: "Roo Cline", - iconPath: expect.any(Object), - env: { - PAGER: "cat" - } - }) - }) - }) -}) \ No newline at end of file + expect(mockCreateTerminal).toHaveBeenCalledWith({ + cwd: "/test/path", + name: "Roo Cline", + iconPath: expect.any(Object), + env: { + PAGER: "cat", + }, + }) + }) + }) +}) diff --git a/src/integrations/workspace/WorkspaceTracker.ts b/src/integrations/workspace/WorkspaceTracker.ts index d97d099..550de84 100644 --- a/src/integrations/workspace/WorkspaceTracker.ts +++ b/src/integrations/workspace/WorkspaceTracker.ts @@ -35,7 +35,7 @@ class WorkspaceTracker { watcher.onDidCreate(async (uri) => { await this.addFilePath(uri.fsPath) this.workspaceDidUpdate() - }) + }), ) // Renaming files triggers a delete and create event @@ -44,7 +44,7 @@ class WorkspaceTracker { if (await this.removeFilePath(uri.fsPath)) { this.workspaceDidUpdate() } - }) + }), ) this.disposables.push(watcher) @@ -64,7 +64,7 @@ class WorkspaceTracker { filePaths: Array.from(this.filePaths).map((file) => { const relativePath = path.relative(cwd, file).toPosix() return file.endsWith("/") ? relativePath + "/" : relativePath - }) + }), }) this.updateTimer = null }, 300) // Debounce for 300ms diff --git a/src/integrations/workspace/__tests__/WorkspaceTracker.test.ts b/src/integrations/workspace/__tests__/WorkspaceTracker.test.ts index e6c6767..44b5648 100644 --- a/src/integrations/workspace/__tests__/WorkspaceTracker.test.ts +++ b/src/integrations/workspace/__tests__/WorkspaceTracker.test.ts @@ -10,144 +10,146 @@ const mockOnDidChange = jest.fn() const mockDispose = jest.fn() const mockWatcher = { - onDidCreate: mockOnDidCreate.mockReturnValue({ dispose: mockDispose }), - onDidDelete: mockOnDidDelete.mockReturnValue({ dispose: mockDispose }), - dispose: mockDispose + onDidCreate: mockOnDidCreate.mockReturnValue({ dispose: mockDispose }), + onDidDelete: mockOnDidDelete.mockReturnValue({ dispose: mockDispose }), + dispose: mockDispose, } jest.mock("vscode", () => ({ - workspace: { - workspaceFolders: [{ - uri: { fsPath: "/test/workspace" }, - name: "test", - index: 0 - }], - createFileSystemWatcher: jest.fn(() => mockWatcher), - fs: { - stat: jest.fn().mockResolvedValue({ type: 1 }) // FileType.File = 1 - } - }, - FileType: { File: 1, Directory: 2 } + workspace: { + workspaceFolders: [ + { + uri: { fsPath: "/test/workspace" }, + name: "test", + index: 0, + }, + ], + createFileSystemWatcher: jest.fn(() => mockWatcher), + fs: { + stat: jest.fn().mockResolvedValue({ type: 1 }), // FileType.File = 1 + }, + }, + FileType: { File: 1, Directory: 2 }, })) jest.mock("../../../services/glob/list-files") describe("WorkspaceTracker", () => { - let workspaceTracker: WorkspaceTracker - let mockProvider: ClineProvider + let workspaceTracker: WorkspaceTracker + let mockProvider: ClineProvider - beforeEach(() => { - jest.clearAllMocks() - jest.useFakeTimers() + beforeEach(() => { + jest.clearAllMocks() + jest.useFakeTimers() - // Create provider mock - mockProvider = { - postMessageToWebview: jest.fn().mockResolvedValue(undefined) - } as unknown as ClineProvider & { postMessageToWebview: jest.Mock } + // Create provider mock + mockProvider = { + postMessageToWebview: jest.fn().mockResolvedValue(undefined), + } as unknown as ClineProvider & { postMessageToWebview: jest.Mock } - // Create tracker instance - workspaceTracker = new WorkspaceTracker(mockProvider) - }) + // Create tracker instance + workspaceTracker = new WorkspaceTracker(mockProvider) + }) - it("should initialize with workspace files", async () => { - const mockFiles = [["/test/workspace/file1.ts", "/test/workspace/file2.ts"], false] - ;(listFiles as jest.Mock).mockResolvedValue(mockFiles) - - await workspaceTracker.initializeFilePaths() - jest.runAllTimers() + it("should initialize with workspace files", async () => { + const mockFiles = [["/test/workspace/file1.ts", "/test/workspace/file2.ts"], false] + ;(listFiles as jest.Mock).mockResolvedValue(mockFiles) - expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ - type: "workspaceUpdated", - filePaths: expect.arrayContaining(["file1.ts", "file2.ts"]) - }) - expect((mockProvider.postMessageToWebview as jest.Mock).mock.calls[0][0].filePaths).toHaveLength(2) - }) + await workspaceTracker.initializeFilePaths() + jest.runAllTimers() - it("should handle file creation events", async () => { - // Get the creation callback and call it - const [[callback]] = mockOnDidCreate.mock.calls - await callback({ fsPath: "/test/workspace/newfile.ts" }) - jest.runAllTimers() + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "workspaceUpdated", + filePaths: expect.arrayContaining(["file1.ts", "file2.ts"]), + }) + expect((mockProvider.postMessageToWebview as jest.Mock).mock.calls[0][0].filePaths).toHaveLength(2) + }) - expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ - type: "workspaceUpdated", - filePaths: ["newfile.ts"] - }) - }) + it("should handle file creation events", async () => { + // Get the creation callback and call it + const [[callback]] = mockOnDidCreate.mock.calls + await callback({ fsPath: "/test/workspace/newfile.ts" }) + jest.runAllTimers() - it("should handle file deletion events", async () => { - // First add a file - const [[createCallback]] = mockOnDidCreate.mock.calls - await createCallback({ fsPath: "/test/workspace/file.ts" }) - jest.runAllTimers() - - // Then delete it - const [[deleteCallback]] = mockOnDidDelete.mock.calls - await deleteCallback({ fsPath: "/test/workspace/file.ts" }) - jest.runAllTimers() + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "workspaceUpdated", + filePaths: ["newfile.ts"], + }) + }) - // The last call should have empty filePaths - expect(mockProvider.postMessageToWebview).toHaveBeenLastCalledWith({ - type: "workspaceUpdated", - filePaths: [] - }) - }) + it("should handle file deletion events", async () => { + // First add a file + const [[createCallback]] = mockOnDidCreate.mock.calls + await createCallback({ fsPath: "/test/workspace/file.ts" }) + jest.runAllTimers() - it("should handle directory paths correctly", async () => { - // Mock stat to return directory type - ;(vscode.workspace.fs.stat as jest.Mock).mockResolvedValueOnce({ type: 2 }) // FileType.Directory = 2 - - const [[callback]] = mockOnDidCreate.mock.calls - await callback({ fsPath: "/test/workspace/newdir" }) - jest.runAllTimers() + // Then delete it + const [[deleteCallback]] = mockOnDidDelete.mock.calls + await deleteCallback({ fsPath: "/test/workspace/file.ts" }) + jest.runAllTimers() - expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ - type: "workspaceUpdated", - filePaths: expect.arrayContaining(["newdir"]) - }) - const lastCall = (mockProvider.postMessageToWebview as jest.Mock).mock.calls.slice(-1)[0] - expect(lastCall[0].filePaths).toHaveLength(1) - }) + // The last call should have empty filePaths + expect(mockProvider.postMessageToWebview).toHaveBeenLastCalledWith({ + type: "workspaceUpdated", + filePaths: [], + }) + }) - it("should respect file limits", async () => { - // Create array of unique file paths for initial load - const files = Array.from({ length: 1001 }, (_, i) => `/test/workspace/file${i}.ts`) - ;(listFiles as jest.Mock).mockResolvedValue([files, false]) - - await workspaceTracker.initializeFilePaths() - jest.runAllTimers() + it("should handle directory paths correctly", async () => { + // Mock stat to return directory type + ;(vscode.workspace.fs.stat as jest.Mock).mockResolvedValueOnce({ type: 2 }) // FileType.Directory = 2 - // Should only have 1000 files initially - const expectedFiles = Array.from({ length: 1000 }, (_, i) => `file${i}.ts`).sort() - const calls = (mockProvider.postMessageToWebview as jest.Mock).mock.calls - - expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ - type: "workspaceUpdated", - filePaths: expect.arrayContaining(expectedFiles) - }) - expect(calls[0][0].filePaths).toHaveLength(1000) + const [[callback]] = mockOnDidCreate.mock.calls + await callback({ fsPath: "/test/workspace/newdir" }) + jest.runAllTimers() - // Should allow adding up to 2000 total files - const [[callback]] = mockOnDidCreate.mock.calls - for (let i = 0; i < 1000; i++) { - await callback({ fsPath: `/test/workspace/extra${i}.ts` }) - } - jest.runAllTimers() + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "workspaceUpdated", + filePaths: expect.arrayContaining(["newdir"]), + }) + const lastCall = (mockProvider.postMessageToWebview as jest.Mock).mock.calls.slice(-1)[0] + expect(lastCall[0].filePaths).toHaveLength(1) + }) - const lastCall = (mockProvider.postMessageToWebview as jest.Mock).mock.calls.slice(-1)[0] - expect(lastCall[0].filePaths).toHaveLength(2000) + it("should respect file limits", async () => { + // Create array of unique file paths for initial load + const files = Array.from({ length: 1001 }, (_, i) => `/test/workspace/file${i}.ts`) + ;(listFiles as jest.Mock).mockResolvedValue([files, false]) - // Adding one more file beyond 2000 should not increase the count - await callback({ fsPath: "/test/workspace/toomany.ts" }) - jest.runAllTimers() + await workspaceTracker.initializeFilePaths() + jest.runAllTimers() - const finalCall = (mockProvider.postMessageToWebview as jest.Mock).mock.calls.slice(-1)[0] - expect(finalCall[0].filePaths).toHaveLength(2000) - }) + // Should only have 1000 files initially + const expectedFiles = Array.from({ length: 1000 }, (_, i) => `file${i}.ts`).sort() + const calls = (mockProvider.postMessageToWebview as jest.Mock).mock.calls - it("should clean up watchers and timers on dispose", () => { - workspaceTracker.dispose() - expect(mockDispose).toHaveBeenCalled() - jest.runAllTimers() // Ensure any pending timers are cleared - }) -}) \ No newline at end of file + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "workspaceUpdated", + filePaths: expect.arrayContaining(expectedFiles), + }) + expect(calls[0][0].filePaths).toHaveLength(1000) + + // Should allow adding up to 2000 total files + const [[callback]] = mockOnDidCreate.mock.calls + for (let i = 0; i < 1000; i++) { + await callback({ fsPath: `/test/workspace/extra${i}.ts` }) + } + jest.runAllTimers() + + const lastCall = (mockProvider.postMessageToWebview as jest.Mock).mock.calls.slice(-1)[0] + expect(lastCall[0].filePaths).toHaveLength(2000) + + // Adding one more file beyond 2000 should not increase the count + await callback({ fsPath: "/test/workspace/toomany.ts" }) + jest.runAllTimers() + + const finalCall = (mockProvider.postMessageToWebview as jest.Mock).mock.calls.slice(-1)[0] + expect(finalCall[0].filePaths).toHaveLength(2000) + }) + + it("should clean up watchers and timers on dispose", () => { + workspaceTracker.dispose() + expect(mockDispose).toHaveBeenCalled() + jest.runAllTimers() // Ensure any pending timers are cleared + }) +}) diff --git a/src/services/browser/BrowserSession.ts b/src/services/browser/BrowserSession.ts index 54e04b8..bed0332 100644 --- a/src/services/browser/BrowserSession.ts +++ b/src/services/browser/BrowserSession.ts @@ -136,7 +136,7 @@ export class BrowserSession { let screenshotBase64 = await this.page.screenshot({ ...options, type: "webp", - quality: (await this.context.globalState.get("screenshotQuality") as number | undefined) ?? 75, + quality: ((await this.context.globalState.get("screenshotQuality")) as number | undefined) ?? 75, }) let screenshot = `data:image/webp;base64,${screenshotBase64}` @@ -247,7 +247,7 @@ export class BrowserSession { } async scrollDown(): Promise { - const size = (await this.context.globalState.get("browserViewportSize") as string | undefined) || "900x600" + const size = ((await this.context.globalState.get("browserViewportSize")) as string | undefined) || "900x600" const height = parseInt(size.split("x")[1]) return this.doAction(async (page) => { await page.evaluate((scrollHeight) => { @@ -261,7 +261,7 @@ export class BrowserSession { } async scrollUp(): Promise { - const size = (await this.context.globalState.get("browserViewportSize") as string | undefined) || "900x600" + const size = ((await this.context.globalState.get("browserViewportSize")) as string | undefined) || "900x600" const height = parseInt(size.split("x")[1]) return this.doAction(async (page) => { await page.evaluate((scrollHeight) => { diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 9004a78..b13851f 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -40,11 +40,11 @@ const StdioConfigSchema = z.object({ args: z.array(z.string()).optional(), env: z.record(z.string()).optional(), alwaysAllow: AlwaysAllowSchema.optional(), - disabled: z.boolean().optional() + disabled: z.boolean().optional(), }) const McpSettingsSchema = z.object({ - mcpServers: z.record(StdioConfigSchema) + mcpServers: z.record(StdioConfigSchema), }) export class McpHub { @@ -63,9 +63,7 @@ export class McpHub { getServers(): McpServer[] { // Only return enabled servers - return this.connections - .filter((conn) => !conn.server.disabled) - .map((conn) => conn.server) + return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server) } async getMcpServersPath(): Promise { @@ -300,9 +298,9 @@ export class McpHub { const alwaysAllowConfig = config.mcpServers[serverName]?.alwaysAllow || [] // Mark tools as always allowed based on settings - const tools = (response?.tools || []).map(tool => ({ + const tools = (response?.tools || []).map((tool) => ({ ...tool, - alwaysAllow: alwaysAllowConfig.includes(tool.name) + alwaysAllow: alwaysAllowConfig.includes(tool.name), })) console.log(`[MCP] Fetched tools for ${serverName}:`, tools) @@ -471,28 +469,28 @@ export class McpHub { } // Public methods for server management - + public async toggleServerDisabled(serverName: string, disabled: boolean): Promise { let settingsPath: string try { settingsPath = await this.getMcpSettingsFilePath() - + // Ensure the settings file exists and is accessible try { await fs.access(settingsPath) } catch (error) { - console.error('Settings file not accessible:', error) - throw new Error('Settings file not accessible') + console.error("Settings file not accessible:", error) + throw new Error("Settings file not accessible") } const content = await fs.readFile(settingsPath, "utf-8") const config = JSON.parse(content) // Validate the config structure - if (!config || typeof config !== 'object') { - throw new Error('Invalid config structure') + if (!config || typeof config !== "object") { + throw new Error("Invalid config structure") } - - if (!config.mcpServers || typeof config.mcpServers !== 'object') { + + if (!config.mcpServers || typeof config.mcpServers !== "object") { config.mcpServers = {} } @@ -500,28 +498,28 @@ export class McpHub { // Create a new server config object to ensure clean structure const serverConfig = { ...config.mcpServers[serverName], - disabled + disabled, } - + // Ensure required fields exist if (!serverConfig.alwaysAllow) { serverConfig.alwaysAllow = [] } - + config.mcpServers[serverName] = serverConfig - + // Write the entire config back const updatedConfig = { - mcpServers: config.mcpServers + mcpServers: config.mcpServers, } - + await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2)) - const connection = this.connections.find(conn => conn.server.name === serverName) + const connection = this.connections.find((conn) => conn.server.name === serverName) if (connection) { try { connection.server.disabled = disabled - + // Only refresh capabilities if connected if (connection.server.status === "connected") { connection.server.tools = await this.fetchToolsList(serverName) @@ -540,7 +538,9 @@ export class McpHub { if (error instanceof Error) { console.error("Error details:", error.message, error.stack) } - vscode.window.showErrorMessage(`Failed to update server state: ${error instanceof Error ? error.message : String(error)}`) + vscode.window.showErrorMessage( + `Failed to update server state: ${error instanceof Error ? error.message : String(error)}`, + ) throw error } } @@ -617,12 +617,11 @@ export class McpHub { await fs.writeFile(settingsPath, JSON.stringify(config, null, 2)) // Update the tools list to reflect the change - const connection = this.connections.find(conn => conn.server.name === serverName) + const connection = this.connections.find((conn) => conn.server.name === serverName) if (connection) { connection.server.tools = await this.fetchToolsList(serverName) await this.notifyWebviewOfServerChanges() } - } catch (error) { console.error("Failed to update always allow settings:", error) vscode.window.showErrorMessage("Failed to update always allow settings") diff --git a/src/services/mcp/__tests__/McpHub.test.ts b/src/services/mcp/__tests__/McpHub.test.ts index cd63e29..dd46183 100644 --- a/src/services/mcp/__tests__/McpHub.test.ts +++ b/src/services/mcp/__tests__/McpHub.test.ts @@ -1,290 +1,292 @@ -import type { McpHub as McpHubType } from '../McpHub' -import type { ClineProvider } from '../../../core/webview/ClineProvider' -import type { ExtensionContext, Uri } from 'vscode' -import type { McpConnection } from '../McpHub' +import type { McpHub as McpHubType } from "../McpHub" +import type { ClineProvider } from "../../../core/webview/ClineProvider" +import type { ExtensionContext, Uri } from "vscode" +import type { McpConnection } from "../McpHub" -const vscode = require('vscode') -const fs = require('fs/promises') -const { McpHub } = require('../McpHub') +const vscode = require("vscode") +const fs = require("fs/promises") +const { McpHub } = require("../McpHub") -jest.mock('vscode') -jest.mock('fs/promises') -jest.mock('../../../core/webview/ClineProvider') +jest.mock("vscode") +jest.mock("fs/promises") +jest.mock("../../../core/webview/ClineProvider") -describe('McpHub', () => { - let mcpHub: McpHubType - let mockProvider: Partial - const mockSettingsPath = '/mock/settings/path/cline_mcp_settings.json' +describe("McpHub", () => { + let mcpHub: McpHubType + let mockProvider: Partial + const mockSettingsPath = "/mock/settings/path/cline_mcp_settings.json" - beforeEach(() => { - jest.clearAllMocks() + beforeEach(() => { + jest.clearAllMocks() - const mockUri: Uri = { - scheme: 'file', - authority: '', - path: '/test/path', - query: '', - fragment: '', - fsPath: '/test/path', - with: jest.fn(), - toJSON: jest.fn() - } + const mockUri: Uri = { + scheme: "file", + authority: "", + path: "/test/path", + query: "", + fragment: "", + fsPath: "/test/path", + with: jest.fn(), + toJSON: jest.fn(), + } - mockProvider = { - ensureSettingsDirectoryExists: jest.fn().mockResolvedValue('/mock/settings/path'), - ensureMcpServersDirectoryExists: jest.fn().mockResolvedValue('/mock/settings/path'), - postMessageToWebview: jest.fn(), - context: { - subscriptions: [], - workspaceState: {} as any, - globalState: {} as any, - secrets: {} as any, - extensionUri: mockUri, - extensionPath: '/test/path', - storagePath: '/test/storage', - globalStoragePath: '/test/global-storage', - environmentVariableCollection: {} as any, - extension: { - id: 'test-extension', - extensionUri: mockUri, - extensionPath: '/test/path', - extensionKind: 1, - isActive: true, - packageJSON: { - version: '1.0.0' - }, - activate: jest.fn(), - exports: undefined - } as any, - asAbsolutePath: (path: string) => path, - storageUri: mockUri, - globalStorageUri: mockUri, - logUri: mockUri, - extensionMode: 1, - logPath: '/test/path', - languageModelAccessInformation: {} as any - } as ExtensionContext - } + mockProvider = { + ensureSettingsDirectoryExists: jest.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: jest.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: jest.fn(), + context: { + subscriptions: [], + workspaceState: {} as any, + globalState: {} as any, + secrets: {} as any, + extensionUri: mockUri, + extensionPath: "/test/path", + storagePath: "/test/storage", + globalStoragePath: "/test/global-storage", + environmentVariableCollection: {} as any, + extension: { + id: "test-extension", + extensionUri: mockUri, + extensionPath: "/test/path", + extensionKind: 1, + isActive: true, + packageJSON: { + version: "1.0.0", + }, + activate: jest.fn(), + exports: undefined, + } as any, + asAbsolutePath: (path: string) => path, + storageUri: mockUri, + globalStorageUri: mockUri, + logUri: mockUri, + extensionMode: 1, + logPath: "/test/path", + languageModelAccessInformation: {} as any, + } as ExtensionContext, + } - // Mock fs.readFile for initial settings - ;(fs.readFile as jest.Mock).mockResolvedValue(JSON.stringify({ - mcpServers: { - 'test-server': { - command: 'node', - args: ['test.js'], - alwaysAllow: ['allowed-tool'] - } - } - })) + // Mock fs.readFile for initial settings + ;(fs.readFile as jest.Mock).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + command: "node", + args: ["test.js"], + alwaysAllow: ["allowed-tool"], + }, + }, + }), + ) - mcpHub = new McpHub(mockProvider as ClineProvider) - }) + mcpHub = new McpHub(mockProvider as ClineProvider) + }) - describe('toggleToolAlwaysAllow', () => { - it('should add tool to always allow list when enabling', async () => { - const mockConfig = { - mcpServers: { - 'test-server': { - command: 'node', - args: ['test.js'], - alwaysAllow: [] - } - } - } + describe("toggleToolAlwaysAllow", () => { + it("should add tool to always allow list when enabling", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + command: "node", + args: ["test.js"], + alwaysAllow: [], + }, + }, + } - // Mock reading initial config - ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) - await mcpHub.toggleToolAlwaysAllow('test-server', 'new-tool', true) + await mcpHub.toggleToolAlwaysAllow("test-server", "new-tool", true) - // Verify the config was updated correctly - const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] - const writtenConfig = JSON.parse(writeCall[1]) - expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toContain('new-tool') - }) + // Verify the config was updated correctly + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers["test-server"].alwaysAllow).toContain("new-tool") + }) - it('should remove tool from always allow list when disabling', async () => { - const mockConfig = { - mcpServers: { - 'test-server': { - command: 'node', - args: ['test.js'], - alwaysAllow: ['existing-tool'] - } - } - } + it("should remove tool from always allow list when disabling", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + command: "node", + args: ["test.js"], + alwaysAllow: ["existing-tool"], + }, + }, + } - // Mock reading initial config - ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) - await mcpHub.toggleToolAlwaysAllow('test-server', 'existing-tool', false) + await mcpHub.toggleToolAlwaysAllow("test-server", "existing-tool", false) - // Verify the config was updated correctly - const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] - const writtenConfig = JSON.parse(writeCall[1]) - expect(writtenConfig.mcpServers['test-server'].alwaysAllow).not.toContain('existing-tool') - }) + // Verify the config was updated correctly + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers["test-server"].alwaysAllow).not.toContain("existing-tool") + }) - it('should initialize alwaysAllow if it does not exist', async () => { - const mockConfig = { - mcpServers: { - 'test-server': { - command: 'node', - args: ['test.js'] - } - } - } + it("should initialize alwaysAllow if it does not exist", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + command: "node", + args: ["test.js"], + }, + }, + } - // Mock reading initial config - ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) - await mcpHub.toggleToolAlwaysAllow('test-server', 'new-tool', true) + await mcpHub.toggleToolAlwaysAllow("test-server", "new-tool", true) - // Verify the config was updated with initialized alwaysAllow - const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] - const writtenConfig = JSON.parse(writeCall[1]) - expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toBeDefined() - expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toContain('new-tool') - }) - }) + // Verify the config was updated with initialized alwaysAllow + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers["test-server"].alwaysAllow).toBeDefined() + expect(writtenConfig.mcpServers["test-server"].alwaysAllow).toContain("new-tool") + }) + }) - describe('server disabled state', () => { - it('should toggle server disabled state', async () => { - const mockConfig = { - mcpServers: { - 'test-server': { - command: 'node', - args: ['test.js'], - disabled: false - } - } - } + describe("server disabled state", () => { + it("should toggle server disabled state", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + command: "node", + args: ["test.js"], + disabled: false, + }, + }, + } - // Mock reading initial config - ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) - await mcpHub.toggleServerDisabled('test-server', true) + await mcpHub.toggleServerDisabled("test-server", true) - // Verify the config was updated correctly - const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] - const writtenConfig = JSON.parse(writeCall[1]) - expect(writtenConfig.mcpServers['test-server'].disabled).toBe(true) - }) + // Verify the config was updated correctly + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers["test-server"].disabled).toBe(true) + }) - it('should filter out disabled servers from getServers', () => { - const mockConnections: McpConnection[] = [ - { - server: { - name: 'enabled-server', - config: '{}', - status: 'connected', - disabled: false - }, - client: {} as any, - transport: {} as any - }, - { - server: { - name: 'disabled-server', - config: '{}', - status: 'connected', - disabled: true - }, - client: {} as any, - transport: {} as any - } - ] + it("should filter out disabled servers from getServers", () => { + const mockConnections: McpConnection[] = [ + { + server: { + name: "enabled-server", + config: "{}", + status: "connected", + disabled: false, + }, + client: {} as any, + transport: {} as any, + }, + { + server: { + name: "disabled-server", + config: "{}", + status: "connected", + disabled: true, + }, + client: {} as any, + transport: {} as any, + }, + ] - mcpHub.connections = mockConnections - const servers = mcpHub.getServers() + mcpHub.connections = mockConnections + const servers = mcpHub.getServers() - expect(servers.length).toBe(1) - expect(servers[0].name).toBe('enabled-server') - }) + expect(servers.length).toBe(1) + expect(servers[0].name).toBe("enabled-server") + }) - it('should prevent calling tools on disabled servers', async () => { - const mockConnection: McpConnection = { - server: { - name: 'disabled-server', - config: '{}', - status: 'connected', - disabled: true - }, - client: { - request: jest.fn().mockResolvedValue({ result: 'success' }) - } as any, - transport: {} as any - } + it("should prevent calling tools on disabled servers", async () => { + const mockConnection: McpConnection = { + server: { + name: "disabled-server", + config: "{}", + status: "connected", + disabled: true, + }, + client: { + request: jest.fn().mockResolvedValue({ result: "success" }), + } as any, + transport: {} as any, + } - mcpHub.connections = [mockConnection] + mcpHub.connections = [mockConnection] - await expect(mcpHub.callTool('disabled-server', 'some-tool', {})) - .rejects - .toThrow('Server "disabled-server" is disabled and cannot be used') - }) + await expect(mcpHub.callTool("disabled-server", "some-tool", {})).rejects.toThrow( + 'Server "disabled-server" is disabled and cannot be used', + ) + }) - it('should prevent reading resources from disabled servers', async () => { - const mockConnection: McpConnection = { - server: { - name: 'disabled-server', - config: '{}', - status: 'connected', - disabled: true - }, - client: { - request: jest.fn() - } as any, - transport: {} as any - } + it("should prevent reading resources from disabled servers", async () => { + const mockConnection: McpConnection = { + server: { + name: "disabled-server", + config: "{}", + status: "connected", + disabled: true, + }, + client: { + request: jest.fn(), + } as any, + transport: {} as any, + } - mcpHub.connections = [mockConnection] + mcpHub.connections = [mockConnection] - await expect(mcpHub.readResource('disabled-server', 'some/uri')) - .rejects - .toThrow('Server "disabled-server" is disabled') - }) - }) + await expect(mcpHub.readResource("disabled-server", "some/uri")).rejects.toThrow( + 'Server "disabled-server" is disabled', + ) + }) + }) - describe('callTool', () => { - it('should execute tool successfully', async () => { - // Mock the connection with a minimal client implementation - const mockConnection: McpConnection = { - server: { - name: 'test-server', - config: JSON.stringify({}), - status: 'connected' as const - }, - client: { - request: jest.fn().mockResolvedValue({ result: 'success' }) - } as any, - transport: { - start: jest.fn(), - close: jest.fn(), - stderr: { on: jest.fn() } - } as any - } + describe("callTool", () => { + it("should execute tool successfully", async () => { + // Mock the connection with a minimal client implementation + const mockConnection: McpConnection = { + server: { + name: "test-server", + config: JSON.stringify({}), + status: "connected" as const, + }, + client: { + request: jest.fn().mockResolvedValue({ result: "success" }), + } as any, + transport: { + start: jest.fn(), + close: jest.fn(), + stderr: { on: jest.fn() }, + } as any, + } - mcpHub.connections = [mockConnection] + mcpHub.connections = [mockConnection] - await mcpHub.callTool('test-server', 'some-tool', {}) + await mcpHub.callTool("test-server", "some-tool", {}) - // Verify the request was made with correct parameters - expect(mockConnection.client.request).toHaveBeenCalledWith( - { - method: 'tools/call', - params: { - name: 'some-tool', - arguments: {} - } - }, - expect.any(Object) - ) - }) + // Verify the request was made with correct parameters + expect(mockConnection.client.request).toHaveBeenCalledWith( + { + method: "tools/call", + params: { + name: "some-tool", + arguments: {}, + }, + }, + expect.any(Object), + ) + }) - it('should throw error if server not found', async () => { - await expect(mcpHub.callTool('non-existent-server', 'some-tool', {})) - .rejects - .toThrow('No connection found for server: non-existent-server') - }) - }) -}) \ No newline at end of file + it("should throw error if server not found", async () => { + await expect(mcpHub.callTool("non-existent-server", "some-tool", {})).rejects.toThrow( + "No connection found for server: non-existent-server", + ) + }) + }) +}) diff --git a/src/services/tree-sitter/__tests__/index.test.ts b/src/services/tree-sitter/__tests__/index.test.ts index 614cccf..4a5782d 100644 --- a/src/services/tree-sitter/__tests__/index.test.ts +++ b/src/services/tree-sitter/__tests__/index.test.ts @@ -1,254 +1,246 @@ -import { parseSourceCodeForDefinitionsTopLevel } from '../index'; -import { listFiles } from '../../glob/list-files'; -import { loadRequiredLanguageParsers } from '../languageParser'; -import { fileExistsAtPath } from '../../../utils/fs'; -import * as fs from 'fs/promises'; -import * as path from 'path'; +import { parseSourceCodeForDefinitionsTopLevel } from "../index" +import { listFiles } from "../../glob/list-files" +import { loadRequiredLanguageParsers } from "../languageParser" +import { fileExistsAtPath } from "../../../utils/fs" +import * as fs from "fs/promises" +import * as path from "path" // Mock dependencies -jest.mock('../../glob/list-files'); -jest.mock('../languageParser'); -jest.mock('../../../utils/fs'); -jest.mock('fs/promises'); +jest.mock("../../glob/list-files") +jest.mock("../languageParser") +jest.mock("../../../utils/fs") +jest.mock("fs/promises") -describe('Tree-sitter Service', () => { - beforeEach(() => { - jest.clearAllMocks(); - (fileExistsAtPath as jest.Mock).mockResolvedValue(true); - }); +describe("Tree-sitter Service", () => { + beforeEach(() => { + jest.clearAllMocks() + ;(fileExistsAtPath as jest.Mock).mockResolvedValue(true) + }) - describe('parseSourceCodeForDefinitionsTopLevel', () => { - it('should handle non-existent directory', async () => { - (fileExistsAtPath as jest.Mock).mockResolvedValue(false); - - const result = await parseSourceCodeForDefinitionsTopLevel('/non/existent/path'); - expect(result).toBe('This directory does not exist or you do not have permission to access it.'); - }); + describe("parseSourceCodeForDefinitionsTopLevel", () => { + it("should handle non-existent directory", async () => { + ;(fileExistsAtPath as jest.Mock).mockResolvedValue(false) - it('should handle empty directory', async () => { - (listFiles as jest.Mock).mockResolvedValue([[], new Set()]); - - const result = await parseSourceCodeForDefinitionsTopLevel('/test/path'); - expect(result).toBe('No source code definitions found.'); - }); + const result = await parseSourceCodeForDefinitionsTopLevel("/non/existent/path") + expect(result).toBe("This directory does not exist or you do not have permission to access it.") + }) - it('should parse TypeScript files correctly', async () => { - const mockFiles = [ - '/test/path/file1.ts', - '/test/path/file2.tsx', - '/test/path/readme.md' - ]; - - (listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]); - - const mockParser = { - parse: jest.fn().mockReturnValue({ - rootNode: 'mockNode' - }) - }; - - const mockQuery = { - captures: jest.fn().mockReturnValue([ - { - node: { - startPosition: { row: 0 }, - endPosition: { row: 0 } - }, - name: 'name.definition' - } - ]) - }; + it("should handle empty directory", async () => { + ;(listFiles as jest.Mock).mockResolvedValue([[], new Set()]) - (loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ - ts: { parser: mockParser, query: mockQuery }, - tsx: { parser: mockParser, query: mockQuery } - }); + const result = await parseSourceCodeForDefinitionsTopLevel("/test/path") + expect(result).toBe("No source code definitions found.") + }) - (fs.readFile as jest.Mock).mockResolvedValue( - 'export class TestClass {\n constructor() {}\n}' - ); + it("should parse TypeScript files correctly", async () => { + const mockFiles = ["/test/path/file1.ts", "/test/path/file2.tsx", "/test/path/readme.md"] - const result = await parseSourceCodeForDefinitionsTopLevel('/test/path'); - - expect(result).toContain('file1.ts'); - expect(result).toContain('file2.tsx'); - expect(result).not.toContain('readme.md'); - expect(result).toContain('export class TestClass'); - }); + ;(listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]) - it('should handle multiple definition types', async () => { - const mockFiles = ['/test/path/file.ts']; - (listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]); - - const mockParser = { - parse: jest.fn().mockReturnValue({ - rootNode: 'mockNode' - }) - }; - - const mockQuery = { - captures: jest.fn().mockReturnValue([ - { - node: { - startPosition: { row: 0 }, - endPosition: { row: 0 } - }, - name: 'name.definition.class' - }, - { - node: { - startPosition: { row: 2 }, - endPosition: { row: 2 } - }, - name: 'name.definition.function' - } - ]) - }; + const mockParser = { + parse: jest.fn().mockReturnValue({ + rootNode: "mockNode", + }), + } - (loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ - ts: { parser: mockParser, query: mockQuery } - }); + const mockQuery = { + captures: jest.fn().mockReturnValue([ + { + node: { + startPosition: { row: 0 }, + endPosition: { row: 0 }, + }, + name: "name.definition", + }, + ]), + } - const fileContent = - 'class TestClass {\n' + - ' constructor() {}\n' + - ' testMethod() {}\n' + - '}'; - - (fs.readFile as jest.Mock).mockResolvedValue(fileContent); + ;(loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ + ts: { parser: mockParser, query: mockQuery }, + tsx: { parser: mockParser, query: mockQuery }, + }) + ;(fs.readFile as jest.Mock).mockResolvedValue("export class TestClass {\n constructor() {}\n}") - const result = await parseSourceCodeForDefinitionsTopLevel('/test/path'); - - expect(result).toContain('class TestClass'); - expect(result).toContain('testMethod()'); - expect(result).toContain('|----'); - }); + const result = await parseSourceCodeForDefinitionsTopLevel("/test/path") - it('should handle parsing errors gracefully', async () => { - const mockFiles = ['/test/path/file.ts']; - (listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]); - - const mockParser = { - parse: jest.fn().mockImplementation(() => { - throw new Error('Parsing error'); - }) - }; - - const mockQuery = { - captures: jest.fn() - }; + expect(result).toContain("file1.ts") + expect(result).toContain("file2.tsx") + expect(result).not.toContain("readme.md") + expect(result).toContain("export class TestClass") + }) - (loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ - ts: { parser: mockParser, query: mockQuery } - }); + it("should handle multiple definition types", async () => { + const mockFiles = ["/test/path/file.ts"] + ;(listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]) - (fs.readFile as jest.Mock).mockResolvedValue('invalid code'); + const mockParser = { + parse: jest.fn().mockReturnValue({ + rootNode: "mockNode", + }), + } - const result = await parseSourceCodeForDefinitionsTopLevel('/test/path'); - expect(result).toBe('No source code definitions found.'); - }); + const mockQuery = { + captures: jest.fn().mockReturnValue([ + { + node: { + startPosition: { row: 0 }, + endPosition: { row: 0 }, + }, + name: "name.definition.class", + }, + { + node: { + startPosition: { row: 2 }, + endPosition: { row: 2 }, + }, + name: "name.definition.function", + }, + ]), + } - it('should respect file limit', async () => { - const mockFiles = Array(100).fill(0).map((_, i) => `/test/path/file${i}.ts`); - (listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]); - - const mockParser = { - parse: jest.fn().mockReturnValue({ - rootNode: 'mockNode' - }) - }; - - const mockQuery = { - captures: jest.fn().mockReturnValue([]) - }; + ;(loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ + ts: { parser: mockParser, query: mockQuery }, + }) - (loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ - ts: { parser: mockParser, query: mockQuery } - }); + const fileContent = "class TestClass {\n" + " constructor() {}\n" + " testMethod() {}\n" + "}" - await parseSourceCodeForDefinitionsTopLevel('/test/path'); - - // Should only process first 50 files - expect(mockParser.parse).toHaveBeenCalledTimes(50); - }); + ;(fs.readFile as jest.Mock).mockResolvedValue(fileContent) - it('should handle various supported file extensions', async () => { - const mockFiles = [ - '/test/path/script.js', - '/test/path/app.py', - '/test/path/main.rs', - '/test/path/program.cpp', - '/test/path/code.go' - ]; - - (listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]); - - const mockParser = { - parse: jest.fn().mockReturnValue({ - rootNode: 'mockNode' - }) - }; - - const mockQuery = { - captures: jest.fn().mockReturnValue([{ - node: { - startPosition: { row: 0 }, - endPosition: { row: 0 } - }, - name: 'name' - }]) - }; + const result = await parseSourceCodeForDefinitionsTopLevel("/test/path") - (loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ - js: { parser: mockParser, query: mockQuery }, - py: { parser: mockParser, query: mockQuery }, - rs: { parser: mockParser, query: mockQuery }, - cpp: { parser: mockParser, query: mockQuery }, - go: { parser: mockParser, query: mockQuery } - }); + expect(result).toContain("class TestClass") + expect(result).toContain("testMethod()") + expect(result).toContain("|----") + }) - (fs.readFile as jest.Mock).mockResolvedValue('function test() {}'); + it("should handle parsing errors gracefully", async () => { + const mockFiles = ["/test/path/file.ts"] + ;(listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]) - const result = await parseSourceCodeForDefinitionsTopLevel('/test/path'); - - expect(result).toContain('script.js'); - expect(result).toContain('app.py'); - expect(result).toContain('main.rs'); - expect(result).toContain('program.cpp'); - expect(result).toContain('code.go'); - }); + const mockParser = { + parse: jest.fn().mockImplementation(() => { + throw new Error("Parsing error") + }), + } - it('should normalize paths in output', async () => { - const mockFiles = ['/test/path/dir\\file.ts']; - (listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]); - - const mockParser = { - parse: jest.fn().mockReturnValue({ - rootNode: 'mockNode' - }) - }; - - const mockQuery = { - captures: jest.fn().mockReturnValue([{ - node: { - startPosition: { row: 0 }, - endPosition: { row: 0 } - }, - name: 'name' - }]) - }; + const mockQuery = { + captures: jest.fn(), + } - (loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ - ts: { parser: mockParser, query: mockQuery } - }); + ;(loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ + ts: { parser: mockParser, query: mockQuery }, + }) + ;(fs.readFile as jest.Mock).mockResolvedValue("invalid code") - (fs.readFile as jest.Mock).mockResolvedValue('class Test {}'); + const result = await parseSourceCodeForDefinitionsTopLevel("/test/path") + expect(result).toBe("No source code definitions found.") + }) - const result = await parseSourceCodeForDefinitionsTopLevel('/test/path'); - - // Should use forward slashes regardless of platform - expect(result).toContain('dir/file.ts'); - expect(result).not.toContain('dir\\file.ts'); - }); - }); -}); \ No newline at end of file + it("should respect file limit", async () => { + const mockFiles = Array(100) + .fill(0) + .map((_, i) => `/test/path/file${i}.ts`) + ;(listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]) + + const mockParser = { + parse: jest.fn().mockReturnValue({ + rootNode: "mockNode", + }), + } + + const mockQuery = { + captures: jest.fn().mockReturnValue([]), + } + + ;(loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ + ts: { parser: mockParser, query: mockQuery }, + }) + + await parseSourceCodeForDefinitionsTopLevel("/test/path") + + // Should only process first 50 files + expect(mockParser.parse).toHaveBeenCalledTimes(50) + }) + + it("should handle various supported file extensions", async () => { + const mockFiles = [ + "/test/path/script.js", + "/test/path/app.py", + "/test/path/main.rs", + "/test/path/program.cpp", + "/test/path/code.go", + ] + + ;(listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]) + + const mockParser = { + parse: jest.fn().mockReturnValue({ + rootNode: "mockNode", + }), + } + + const mockQuery = { + captures: jest.fn().mockReturnValue([ + { + node: { + startPosition: { row: 0 }, + endPosition: { row: 0 }, + }, + name: "name", + }, + ]), + } + + ;(loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ + js: { parser: mockParser, query: mockQuery }, + py: { parser: mockParser, query: mockQuery }, + rs: { parser: mockParser, query: mockQuery }, + cpp: { parser: mockParser, query: mockQuery }, + go: { parser: mockParser, query: mockQuery }, + }) + ;(fs.readFile as jest.Mock).mockResolvedValue("function test() {}") + + const result = await parseSourceCodeForDefinitionsTopLevel("/test/path") + + expect(result).toContain("script.js") + expect(result).toContain("app.py") + expect(result).toContain("main.rs") + expect(result).toContain("program.cpp") + expect(result).toContain("code.go") + }) + + it("should normalize paths in output", async () => { + const mockFiles = ["/test/path/dir\\file.ts"] + ;(listFiles as jest.Mock).mockResolvedValue([mockFiles, new Set()]) + + const mockParser = { + parse: jest.fn().mockReturnValue({ + rootNode: "mockNode", + }), + } + + const mockQuery = { + captures: jest.fn().mockReturnValue([ + { + node: { + startPosition: { row: 0 }, + endPosition: { row: 0 }, + }, + name: "name", + }, + ]), + } + + ;(loadRequiredLanguageParsers as jest.Mock).mockResolvedValue({ + ts: { parser: mockParser, query: mockQuery }, + }) + ;(fs.readFile as jest.Mock).mockResolvedValue("class Test {}") + + const result = await parseSourceCodeForDefinitionsTopLevel("/test/path") + + // Should use forward slashes regardless of platform + expect(result).toContain("dir/file.ts") + expect(result).not.toContain("dir\\file.ts") + }) + }) +}) diff --git a/src/services/tree-sitter/__tests__/languageParser.test.ts b/src/services/tree-sitter/__tests__/languageParser.test.ts index 538a2eb..1b92d81 100644 --- a/src/services/tree-sitter/__tests__/languageParser.test.ts +++ b/src/services/tree-sitter/__tests__/languageParser.test.ts @@ -1,128 +1,118 @@ -import { loadRequiredLanguageParsers } from '../languageParser'; -import Parser from 'web-tree-sitter'; +import { loadRequiredLanguageParsers } from "../languageParser" +import Parser from "web-tree-sitter" // Mock web-tree-sitter -const mockSetLanguage = jest.fn(); -jest.mock('web-tree-sitter', () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - setLanguage: mockSetLanguage - })) - }; -}); +const mockSetLanguage = jest.fn() +jest.mock("web-tree-sitter", () => { + return { + __esModule: true, + default: jest.fn().mockImplementation(() => ({ + setLanguage: mockSetLanguage, + })), + } +}) // Add static methods to Parser mock -const ParserMock = Parser as jest.MockedClass; -ParserMock.init = jest.fn().mockResolvedValue(undefined); +const ParserMock = Parser as jest.MockedClass +ParserMock.init = jest.fn().mockResolvedValue(undefined) ParserMock.Language = { - load: jest.fn().mockResolvedValue({ - query: jest.fn().mockReturnValue('mockQuery') - }), - prototype: {} // Add required prototype property -} as unknown as typeof Parser.Language; + load: jest.fn().mockResolvedValue({ + query: jest.fn().mockReturnValue("mockQuery"), + }), + prototype: {}, // Add required prototype property +} as unknown as typeof Parser.Language -describe('Language Parser', () => { - beforeEach(() => { - jest.clearAllMocks(); - }); +describe("Language Parser", () => { + beforeEach(() => { + jest.clearAllMocks() + }) - describe('loadRequiredLanguageParsers', () => { - it('should initialize parser only once', async () => { - const files = ['test.js', 'test2.js']; - await loadRequiredLanguageParsers(files); - await loadRequiredLanguageParsers(files); - - expect(ParserMock.init).toHaveBeenCalledTimes(1); - }); + describe("loadRequiredLanguageParsers", () => { + it("should initialize parser only once", async () => { + const files = ["test.js", "test2.js"] + await loadRequiredLanguageParsers(files) + await loadRequiredLanguageParsers(files) - it('should load JavaScript parser for .js and .jsx files', async () => { - const files = ['test.js', 'test.jsx']; - const parsers = await loadRequiredLanguageParsers(files); - - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-javascript.wasm') - ); - expect(parsers.js).toBeDefined(); - expect(parsers.jsx).toBeDefined(); - expect(parsers.js.query).toBeDefined(); - expect(parsers.jsx.query).toBeDefined(); - }); + expect(ParserMock.init).toHaveBeenCalledTimes(1) + }) - it('should load TypeScript parser for .ts and .tsx files', async () => { - const files = ['test.ts', 'test.tsx']; - const parsers = await loadRequiredLanguageParsers(files); - - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-typescript.wasm') - ); - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-tsx.wasm') - ); - expect(parsers.ts).toBeDefined(); - expect(parsers.tsx).toBeDefined(); - }); + it("should load JavaScript parser for .js and .jsx files", async () => { + const files = ["test.js", "test.jsx"] + const parsers = await loadRequiredLanguageParsers(files) - it('should load Python parser for .py files', async () => { - const files = ['test.py']; - const parsers = await loadRequiredLanguageParsers(files); - - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-python.wasm') - ); - expect(parsers.py).toBeDefined(); - }); + expect(ParserMock.Language.load).toHaveBeenCalledWith( + expect.stringContaining("tree-sitter-javascript.wasm"), + ) + expect(parsers.js).toBeDefined() + expect(parsers.jsx).toBeDefined() + expect(parsers.js.query).toBeDefined() + expect(parsers.jsx.query).toBeDefined() + }) - it('should load multiple language parsers as needed', async () => { - const files = ['test.js', 'test.py', 'test.rs', 'test.go']; - const parsers = await loadRequiredLanguageParsers(files); - - expect(ParserMock.Language.load).toHaveBeenCalledTimes(4); - expect(parsers.js).toBeDefined(); - expect(parsers.py).toBeDefined(); - expect(parsers.rs).toBeDefined(); - expect(parsers.go).toBeDefined(); - }); + it("should load TypeScript parser for .ts and .tsx files", async () => { + const files = ["test.ts", "test.tsx"] + const parsers = await loadRequiredLanguageParsers(files) - it('should handle C/C++ files correctly', async () => { - const files = ['test.c', 'test.h', 'test.cpp', 'test.hpp']; - const parsers = await loadRequiredLanguageParsers(files); - - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-c.wasm') - ); - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-cpp.wasm') - ); - expect(parsers.c).toBeDefined(); - expect(parsers.h).toBeDefined(); - expect(parsers.cpp).toBeDefined(); - expect(parsers.hpp).toBeDefined(); - }); + expect(ParserMock.Language.load).toHaveBeenCalledWith( + expect.stringContaining("tree-sitter-typescript.wasm"), + ) + expect(ParserMock.Language.load).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-tsx.wasm")) + expect(parsers.ts).toBeDefined() + expect(parsers.tsx).toBeDefined() + }) - it('should throw error for unsupported file extensions', async () => { - const files = ['test.unsupported']; - - await expect(loadRequiredLanguageParsers(files)).rejects.toThrow( - 'Unsupported language: unsupported' - ); - }); + it("should load Python parser for .py files", async () => { + const files = ["test.py"] + const parsers = await loadRequiredLanguageParsers(files) - it('should load each language only once for multiple files', async () => { - const files = ['test1.js', 'test2.js', 'test3.js']; - await loadRequiredLanguageParsers(files); - - expect(ParserMock.Language.load).toHaveBeenCalledTimes(1); - expect(ParserMock.Language.load).toHaveBeenCalledWith( - expect.stringContaining('tree-sitter-javascript.wasm') - ); - }); + expect(ParserMock.Language.load).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-python.wasm")) + expect(parsers.py).toBeDefined() + }) - it('should set language for each parser instance', async () => { - const files = ['test.js', 'test.py']; - await loadRequiredLanguageParsers(files); - - expect(mockSetLanguage).toHaveBeenCalledTimes(2); - }); - }); -}); \ No newline at end of file + it("should load multiple language parsers as needed", async () => { + const files = ["test.js", "test.py", "test.rs", "test.go"] + const parsers = await loadRequiredLanguageParsers(files) + + expect(ParserMock.Language.load).toHaveBeenCalledTimes(4) + expect(parsers.js).toBeDefined() + expect(parsers.py).toBeDefined() + expect(parsers.rs).toBeDefined() + expect(parsers.go).toBeDefined() + }) + + it("should handle C/C++ files correctly", async () => { + const files = ["test.c", "test.h", "test.cpp", "test.hpp"] + const parsers = await loadRequiredLanguageParsers(files) + + expect(ParserMock.Language.load).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-c.wasm")) + expect(ParserMock.Language.load).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-cpp.wasm")) + expect(parsers.c).toBeDefined() + expect(parsers.h).toBeDefined() + expect(parsers.cpp).toBeDefined() + expect(parsers.hpp).toBeDefined() + }) + + it("should throw error for unsupported file extensions", async () => { + const files = ["test.unsupported"] + + await expect(loadRequiredLanguageParsers(files)).rejects.toThrow("Unsupported language: unsupported") + }) + + it("should load each language only once for multiple files", async () => { + const files = ["test1.js", "test2.js", "test3.js"] + await loadRequiredLanguageParsers(files) + + expect(ParserMock.Language.load).toHaveBeenCalledTimes(1) + expect(ParserMock.Language.load).toHaveBeenCalledWith( + expect.stringContaining("tree-sitter-javascript.wasm"), + ) + }) + + it("should set language for each parser instance", async () => { + const files = ["test.js", "test.py"] + await loadRequiredLanguageParsers(files) + + expect(mockSetLanguage).toHaveBeenCalledTimes(2) + }) + }) +}) diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 73682a4..10fa0c5 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -94,7 +94,7 @@ export interface ExtensionState { mode: Mode modeApiConfigs?: Record enhancementApiConfigId?: string - experimentalDiffStrategy?: boolean + experimentalDiffStrategy?: boolean autoApprovalEnabled?: boolean } diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 1e6e19f..28ae9c9 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -1,7 +1,7 @@ import { ApiConfiguration, ApiProvider } from "./api" import { Mode, PromptComponent } from "./modes" -export type PromptMode = Mode | 'enhance' +export type PromptMode = Mode | "enhance" export type AudioType = "notification" | "celebration" | "progress_loop" @@ -72,7 +72,7 @@ export interface WebviewMessage { | "getSystemPrompt" | "systemPrompt" | "enhancementApiConfigId" - | "experimentalDiffStrategy" + | "experimentalDiffStrategy" | "autoApprovalEnabled" text?: string disabled?: boolean diff --git a/src/shared/__tests__/checkExistApiConfig.test.ts b/src/shared/__tests__/checkExistApiConfig.test.ts index c0fcb64..25c967c 100644 --- a/src/shared/__tests__/checkExistApiConfig.test.ts +++ b/src/shared/__tests__/checkExistApiConfig.test.ts @@ -1,57 +1,57 @@ -import { checkExistKey } from '../checkExistApiConfig'; -import { ApiConfiguration } from '../api'; +import { checkExistKey } from "../checkExistApiConfig" +import { ApiConfiguration } from "../api" -describe('checkExistKey', () => { - it('should return false for undefined config', () => { - expect(checkExistKey(undefined)).toBe(false); - }); +describe("checkExistKey", () => { + it("should return false for undefined config", () => { + expect(checkExistKey(undefined)).toBe(false) + }) - it('should return false for empty config', () => { - const config: ApiConfiguration = {}; - expect(checkExistKey(config)).toBe(false); - }); + it("should return false for empty config", () => { + const config: ApiConfiguration = {} + expect(checkExistKey(config)).toBe(false) + }) - it('should return true when one key is defined', () => { - const config: ApiConfiguration = { - apiKey: 'test-key' - }; - expect(checkExistKey(config)).toBe(true); - }); + it("should return true when one key is defined", () => { + const config: ApiConfiguration = { + apiKey: "test-key", + } + expect(checkExistKey(config)).toBe(true) + }) - it('should return true when multiple keys are defined', () => { - const config: ApiConfiguration = { - apiKey: 'test-key', - glamaApiKey: 'glama-key', - openRouterApiKey: 'openrouter-key' - }; - expect(checkExistKey(config)).toBe(true); - }); + it("should return true when multiple keys are defined", () => { + const config: ApiConfiguration = { + apiKey: "test-key", + glamaApiKey: "glama-key", + openRouterApiKey: "openrouter-key", + } + expect(checkExistKey(config)).toBe(true) + }) - it('should return true when only non-key fields are undefined', () => { - const config: ApiConfiguration = { - apiKey: 'test-key', - apiProvider: undefined, - anthropicBaseUrl: undefined - }; - expect(checkExistKey(config)).toBe(true); - }); + it("should return true when only non-key fields are undefined", () => { + const config: ApiConfiguration = { + apiKey: "test-key", + apiProvider: undefined, + anthropicBaseUrl: undefined, + } + expect(checkExistKey(config)).toBe(true) + }) - it('should return false when all key fields are undefined', () => { - const config: ApiConfiguration = { - apiKey: undefined, - glamaApiKey: undefined, - openRouterApiKey: undefined, - awsRegion: undefined, - vertexProjectId: undefined, - openAiApiKey: undefined, - ollamaModelId: undefined, - lmStudioModelId: undefined, - geminiApiKey: undefined, - openAiNativeApiKey: undefined, - deepSeekApiKey: undefined, - mistralApiKey: undefined, - vsCodeLmModelSelector: undefined - }; - expect(checkExistKey(config)).toBe(false); - }); -}); \ No newline at end of file + it("should return false when all key fields are undefined", () => { + const config: ApiConfiguration = { + apiKey: undefined, + glamaApiKey: undefined, + openRouterApiKey: undefined, + awsRegion: undefined, + vertexProjectId: undefined, + openAiApiKey: undefined, + ollamaModelId: undefined, + lmStudioModelId: undefined, + geminiApiKey: undefined, + openAiNativeApiKey: undefined, + deepSeekApiKey: undefined, + mistralApiKey: undefined, + vsCodeLmModelSelector: undefined, + } + expect(checkExistKey(config)).toBe(false) + }) +}) diff --git a/src/shared/__tests__/vsCodeSelectorUtils.test.ts b/src/shared/__tests__/vsCodeSelectorUtils.test.ts index dd4ed38..6e6c188 100644 --- a/src/shared/__tests__/vsCodeSelectorUtils.test.ts +++ b/src/shared/__tests__/vsCodeSelectorUtils.test.ts @@ -1,44 +1,44 @@ -import { stringifyVsCodeLmModelSelector, SELECTOR_SEPARATOR } from '../vsCodeSelectorUtils'; -import { LanguageModelChatSelector } from 'vscode'; +import { stringifyVsCodeLmModelSelector, SELECTOR_SEPARATOR } from "../vsCodeSelectorUtils" +import { LanguageModelChatSelector } from "vscode" -describe('vsCodeSelectorUtils', () => { - describe('stringifyVsCodeLmModelSelector', () => { - it('should join all defined selector properties with separator', () => { +describe("vsCodeSelectorUtils", () => { + describe("stringifyVsCodeLmModelSelector", () => { + it("should join all defined selector properties with separator", () => { const selector: LanguageModelChatSelector = { - vendor: 'test-vendor', - family: 'test-family', - version: 'v1', - id: 'test-id' - }; + vendor: "test-vendor", + family: "test-family", + version: "v1", + id: "test-id", + } - const result = stringifyVsCodeLmModelSelector(selector); - expect(result).toBe('test-vendor/test-family/v1/test-id'); - }); + const result = stringifyVsCodeLmModelSelector(selector) + expect(result).toBe("test-vendor/test-family/v1/test-id") + }) - it('should skip undefined properties', () => { + it("should skip undefined properties", () => { const selector: LanguageModelChatSelector = { - vendor: 'test-vendor', - family: 'test-family' - }; + vendor: "test-vendor", + family: "test-family", + } - const result = stringifyVsCodeLmModelSelector(selector); - expect(result).toBe('test-vendor/test-family'); - }); + const result = stringifyVsCodeLmModelSelector(selector) + expect(result).toBe("test-vendor/test-family") + }) - it('should handle empty selector', () => { - const selector: LanguageModelChatSelector = {}; + it("should handle empty selector", () => { + const selector: LanguageModelChatSelector = {} - const result = stringifyVsCodeLmModelSelector(selector); - expect(result).toBe(''); - }); + const result = stringifyVsCodeLmModelSelector(selector) + expect(result).toBe("") + }) - it('should handle selector with only one property', () => { + it("should handle selector with only one property", () => { const selector: LanguageModelChatSelector = { - vendor: 'test-vendor' - }; + vendor: "test-vendor", + } - const result = stringifyVsCodeLmModelSelector(selector); - expect(result).toBe('test-vendor'); - }); - }); -}); \ No newline at end of file + const result = stringifyVsCodeLmModelSelector(selector) + expect(result).toBe("test-vendor") + }) + }) +}) diff --git a/src/shared/api.ts b/src/shared/api.ts index 9721f65..4fd25ba 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1,4 +1,4 @@ -import * as vscode from 'vscode'; +import * as vscode from "vscode" export type ApiProvider = | "anthropic" @@ -126,24 +126,24 @@ export const anthropicModels = { // AWS Bedrock // https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html export interface MessageContent { - type: 'text' | 'image' | 'video' | 'tool_use' | 'tool_result'; - text?: string; - source?: { - type: 'base64'; - data: string | Uint8Array; // string for Anthropic, Uint8Array for Bedrock - media_type: 'image/jpeg' | 'image/png' | 'image/gif' | 'image/webp'; - }; - // Video specific fields - format?: string; - s3Location?: { - uri: string; - bucketOwner?: string; - }; - // Tool use and result fields - toolUseId?: string; - name?: string; - input?: any; - output?: any; // Used for tool_result type + type: "text" | "image" | "video" | "tool_use" | "tool_result" + text?: string + source?: { + type: "base64" + data: string | Uint8Array // string for Anthropic, Uint8Array for Bedrock + media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp" + } + // Video specific fields + format?: string + s3Location?: { + uri: string + bucketOwner?: string + } + // Tool use and result fields + toolUseId?: string + name?: string + input?: any + output?: any // Used for tool_result type } export type BedrockModelId = keyof typeof bedrockModels @@ -192,7 +192,6 @@ export const bedrockModels = { outputPrice: 15.0, cacheWritesPrice: 3.75, // per million tokens cacheReadsPrice: 0.3, // per million tokens - }, "anthropic.claude-3-5-haiku-20241022-v1:0": { maxTokens: 8192, @@ -203,7 +202,6 @@ export const bedrockModels = { outputPrice: 5.0, cacheWritesPrice: 1.0, cacheReadsPrice: 0.08, - }, "anthropic.claude-3-5-sonnet-20240620-v1:0": { maxTokens: 8192, @@ -237,7 +235,7 @@ export const bedrockModels = { inputPrice: 0.25, outputPrice: 1.25, }, - "meta.llama3-2-90b-instruct-v1:0" : { + "meta.llama3-2-90b-instruct-v1:0": { maxTokens: 8192, contextWindow: 128_000, supportsImages: true, @@ -246,7 +244,7 @@ export const bedrockModels = { inputPrice: 0.72, outputPrice: 0.72, }, - "meta.llama3-2-11b-instruct-v1:0" : { + "meta.llama3-2-11b-instruct-v1:0": { maxTokens: 8192, contextWindow: 128_000, supportsImages: true, @@ -255,7 +253,7 @@ export const bedrockModels = { inputPrice: 0.16, outputPrice: 0.16, }, - "meta.llama3-2-3b-instruct-v1:0" : { + "meta.llama3-2-3b-instruct-v1:0": { maxTokens: 8192, contextWindow: 128_000, supportsImages: false, @@ -264,7 +262,7 @@ export const bedrockModels = { inputPrice: 0.15, outputPrice: 0.15, }, - "meta.llama3-2-1b-instruct-v1:0" : { + "meta.llama3-2-1b-instruct-v1:0": { maxTokens: 8192, contextWindow: 128_000, supportsImages: false, @@ -273,7 +271,7 @@ export const bedrockModels = { inputPrice: 0.1, outputPrice: 0.1, }, - "meta.llama3-1-405b-instruct-v1:0" : { + "meta.llama3-1-405b-instruct-v1:0": { maxTokens: 8192, contextWindow: 128_000, supportsImages: false, @@ -282,7 +280,7 @@ export const bedrockModels = { inputPrice: 2.4, outputPrice: 2.4, }, - "meta.llama3-1-70b-instruct-v1:0" : { + "meta.llama3-1-70b-instruct-v1:0": { maxTokens: 8192, contextWindow: 128_000, supportsImages: false, @@ -291,7 +289,7 @@ export const bedrockModels = { inputPrice: 0.72, outputPrice: 0.72, }, - "meta.llama3-1-8b-instruct-v1:0" : { + "meta.llama3-1-8b-instruct-v1:0": { maxTokens: 8192, contextWindow: 8_000, supportsImages: false, @@ -300,8 +298,8 @@ export const bedrockModels = { inputPrice: 0.22, outputPrice: 0.22, }, - "meta.llama3-70b-instruct-v1:0" : { - maxTokens: 2048 , + "meta.llama3-70b-instruct-v1:0": { + maxTokens: 2048, contextWindow: 8_000, supportsImages: false, supportsComputerUse: false, @@ -309,8 +307,8 @@ export const bedrockModels = { inputPrice: 2.65, outputPrice: 3.5, }, - "meta.llama3-8b-instruct-v1:0" : { - maxTokens: 2048 , + "meta.llama3-8b-instruct-v1:0": { + maxTokens: 2048, contextWindow: 4_000, supportsImages: false, supportsComputerUse: false, @@ -488,7 +486,7 @@ export type OpenAiNativeModelId = keyof typeof openAiNativeModels export const openAiNativeDefaultModelId: OpenAiNativeModelId = "gpt-4o" export const openAiNativeModels = { // don't support tool use yet - "o1": { + o1: { maxTokens: 100_000, contextWindow: 200_000, supportsImages: true, @@ -540,8 +538,8 @@ export const deepSeekModels = { contextWindow: 64_000, supportsImages: false, supportsPromptCache: false, - inputPrice: 0.014, // $0.014 per million tokens - outputPrice: 0.28, // $0.28 per million tokens + inputPrice: 0.014, // $0.014 per million tokens + outputPrice: 0.28, // $0.28 per million tokens description: `DeepSeek-V3 achieves a significant breakthrough in inference speed over previous models. It tops the leaderboard among open-source models and rivals the most advanced closed-source models globally.`, }, } as const satisfies Record @@ -551,7 +549,6 @@ export const deepSeekModels = { // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs export const azureOpenAiDefaultApiVersion = "2024-08-01-preview" - // Mistral // https://docs.mistral.ai/getting-started/models/models_overview/ export type MistralModelId = keyof typeof mistralModels diff --git a/src/shared/checkExistApiConfig.ts b/src/shared/checkExistApiConfig.ts index 6dec4ff..8cb8055 100644 --- a/src/shared/checkExistApiConfig.ts +++ b/src/shared/checkExistApiConfig.ts @@ -1,21 +1,21 @@ -import { ApiConfiguration } from "../shared/api"; +import { ApiConfiguration } from "../shared/api" export function checkExistKey(config: ApiConfiguration | undefined) { return config ? [ - config.apiKey, - config.glamaApiKey, - config.openRouterApiKey, - config.awsRegion, - config.vertexProjectId, - config.openAiApiKey, - config.ollamaModelId, - config.lmStudioModelId, - config.geminiApiKey, - config.openAiNativeApiKey, - config.deepSeekApiKey, - config.mistralApiKey, - config.vsCodeLmModelSelector, - ].some((key) => key !== undefined) - : false; + config.apiKey, + config.glamaApiKey, + config.openRouterApiKey, + config.awsRegion, + config.vertexProjectId, + config.openAiApiKey, + config.ollamaModelId, + config.lmStudioModelId, + config.geminiApiKey, + config.openAiNativeApiKey, + config.deepSeekApiKey, + config.mistralApiKey, + config.vsCodeLmModelSelector, + ].some((key) => key !== undefined) + : false } diff --git a/src/shared/context-mentions.ts b/src/shared/context-mentions.ts index ed2a657..10375ff 100644 --- a/src/shared/context-mentions.ts +++ b/src/shared/context-mentions.ts @@ -49,11 +49,12 @@ Mention regex: - `mentionRegexGlobal`: Creates a global version of the `mentionRegex` to find all matches within a given string. */ -export const mentionRegex = /@((?:\/|\w+:\/\/)[^\s]+?|[a-f0-9]{7,40}\b|problems\b|git-changes\b)(?=[.,;:!?]?(?=[\s\r\n]|$))/ +export const mentionRegex = + /@((?:\/|\w+:\/\/)[^\s]+?|[a-f0-9]{7,40}\b|problems\b|git-changes\b)(?=[.,;:!?]?(?=[\s\r\n]|$))/ export const mentionRegexGlobal = new RegExp(mentionRegex.source, "g") export interface MentionSuggestion { - type: 'file' | 'folder' | 'git' | 'problems' + type: "file" | "folder" | "git" | "problems" label: string description?: string value: string @@ -61,7 +62,7 @@ export interface MentionSuggestion { } export interface GitMentionSuggestion extends MentionSuggestion { - type: 'git' + type: "git" hash: string shortHash: string subject: string @@ -69,17 +70,23 @@ export interface GitMentionSuggestion extends MentionSuggestion { date: string } -export function formatGitSuggestion(commit: { hash: string; shortHash: string; subject: string; author: string; date: string }): GitMentionSuggestion { +export function formatGitSuggestion(commit: { + hash: string + shortHash: string + subject: string + author: string + date: string +}): GitMentionSuggestion { return { - type: 'git', + type: "git", label: commit.subject, description: `${commit.shortHash} by ${commit.author} on ${commit.date}`, value: commit.hash, - icon: '$(git-commit)', // VSCode git commit icon + icon: "$(git-commit)", // VSCode git commit icon hash: commit.hash, shortHash: commit.shortHash, subject: commit.subject, author: commit.author, - date: commit.date + date: commit.date, } } diff --git a/src/shared/modes.ts b/src/shared/modes.ts index 475ee55..9de7e8a 100644 --- a/src/shared/modes.ts +++ b/src/shared/modes.ts @@ -1,187 +1,189 @@ // Tool options for specific tools export type ToolOptions = { - string: readonly string[]; + string: readonly string[] } // Tool configuration tuple type -export type ToolConfig = readonly [string] | readonly [string, ToolOptions]; +export type ToolConfig = readonly [string] | readonly [string, ToolOptions] // Mode types -export type Mode = string; +export type Mode = string // Mode configuration type export type ModeConfig = { - slug: string; - name: string; - roleDefinition: string; - tools: readonly ToolConfig[]; + slug: string + name: string + roleDefinition: string + tools: readonly ToolConfig[] } // Separate enhance prompt type and definition export type EnhanceConfig = { - prompt: string; + prompt: string } export const enhance: EnhanceConfig = { - prompt: "Generate an enhanced version of this prompt (reply with only the enhanced prompt - no conversation, explanations, lead-in, bullet points, placeholders, or surrounding quotes):" -} as const; + prompt: "Generate an enhanced version of this prompt (reply with only the enhanced prompt - no conversation, explanations, lead-in, bullet points, placeholders, or surrounding quotes):", +} as const // Main modes configuration as an ordered array export const modes: readonly ModeConfig[] = [ - { - slug: 'code', - name: 'Code', - roleDefinition: "You are Cline, a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.", - tools: [ - ['execute_command'], - ['read_file'], - ['write_to_file'], - ['apply_diff'], - ['search_files'], - ['list_files'], - ['list_code_definition_names'], - ['browser_action'], - ['use_mcp_tool'], - ['access_mcp_resource'], - ['ask_followup_question'], - ['attempt_completion'], - ] as const - }, - { - slug: 'architect', - name: 'Architect', - roleDefinition: "You are Cline, a software architecture expert specializing in analyzing codebases, identifying patterns, and providing high-level technical guidance. You excel at understanding complex systems, evaluating architectural decisions, and suggesting improvements while maintaining a read-only approach to the codebase. Make sure to help the user come up with a solid implementation plan for their project and don't rush to switch to implementing code.", - tools: [ - ['read_file'], - ['search_files'], - ['list_files'], - ['list_code_definition_names'], - ['browser_action'], - ['use_mcp_tool'], - ['access_mcp_resource'], - ['ask_followup_question'], - ['attempt_completion'], - ] as const - }, - { - slug: 'ask', - name: 'Ask', - roleDefinition: "You are Cline, a knowledgeable technical assistant focused on answering questions and providing information about software development, technology, and related topics. You can analyze code, explain concepts, and access external resources while maintaining a read-only approach to the codebase. Make sure to answer the user's questions and don't rush to switch to implementing code.", - tools: [ - ['read_file'], - ['search_files'], - ['list_files'], - ['list_code_definition_names'], - ['browser_action'], - ['use_mcp_tool'], - ['access_mcp_resource'], - ['ask_followup_question'], - ['attempt_completion'], - ] as const - }, - { - slug: 'test', - name: 'Test', - roleDefinition: "You are Cline, a software test engineering expert specializing in writing comprehensive test suites and ensuring thorough test coverage. You excel at writing unit tests, integration tests, and end-to-end tests that cover all edge cases while maintaining existing behavior. You must ask the user to confirm before making ANY changes to non-test code, and before implementing any test changes, you always ask the user to confirm your test plan. You focus on: 1) Writing tests that verify functionality without changing existing behavior, 2) Ensuring comprehensive test coverage including edge cases and error conditions, 3) Following testing best practices and patterns appropriate for the language/framework, 4) Using mocks, stubs, and fixtures effectively, 5) Writing clear, maintainable test code with descriptive names and good documentation.", - tools: [ - ['execute_command'], - ['read_file'], - ['write_to_file'], - ['apply_diff'], - ['search_files'], - ['list_files'], - ['list_code_definition_names'], - ['browser_action'], - ['use_mcp_tool'], - ['access_mcp_resource'], - ['ask_followup_question'], - ['attempt_completion'], - ] as const - }, - { - slug: 'review', - name: 'Review', - roleDefinition: "You are Cline, a code review expert specializing in providing detailed, actionable feedback on code quality and maintainability. You excel at: 1) Identifying potential bugs, security vulnerabilities, and performance issues, 2) Ensuring code follows project standards, patterns, and best practices, 3) Checking for proper error handling and edge cases, 4) Verifying documentation completeness and clarity, 5) Suggesting specific, actionable improvements with examples. You maintain a read-only approach to the codebase and focus on helping developers improve their code through clear, constructive feedback.", - tools: [ - ['read_file'], - ['search_files'], - ['list_files'], - ['list_code_definition_names'], - ['browser_action'], - ['use_mcp_tool'], - ['access_mcp_resource'], - ['ask_followup_question'], - ['attempt_completion'], - ] as const - }, -] as const; + { + slug: "code", + name: "Code", + roleDefinition: + "You are Cline, a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.", + tools: [ + ["execute_command"], + ["read_file"], + ["write_to_file"], + ["apply_diff"], + ["search_files"], + ["list_files"], + ["list_code_definition_names"], + ["browser_action"], + ["use_mcp_tool"], + ["access_mcp_resource"], + ["ask_followup_question"], + ["attempt_completion"], + ] as const, + }, + { + slug: "architect", + name: "Architect", + roleDefinition: + "You are Cline, a software architecture expert specializing in analyzing codebases, identifying patterns, and providing high-level technical guidance. You excel at understanding complex systems, evaluating architectural decisions, and suggesting improvements while maintaining a read-only approach to the codebase. Make sure to help the user come up with a solid implementation plan for their project and don't rush to switch to implementing code.", + tools: [ + ["read_file"], + ["search_files"], + ["list_files"], + ["list_code_definition_names"], + ["browser_action"], + ["use_mcp_tool"], + ["access_mcp_resource"], + ["ask_followup_question"], + ["attempt_completion"], + ] as const, + }, + { + slug: "ask", + name: "Ask", + roleDefinition: + "You are Cline, a knowledgeable technical assistant focused on answering questions and providing information about software development, technology, and related topics. You can analyze code, explain concepts, and access external resources while maintaining a read-only approach to the codebase. Make sure to answer the user's questions and don't rush to switch to implementing code.", + tools: [ + ["read_file"], + ["search_files"], + ["list_files"], + ["list_code_definition_names"], + ["browser_action"], + ["use_mcp_tool"], + ["access_mcp_resource"], + ["ask_followup_question"], + ["attempt_completion"], + ] as const, + }, + { + slug: "test", + name: "Test", + roleDefinition: + "You are Cline, a software test engineering expert specializing in writing comprehensive test suites and ensuring thorough test coverage. You excel at writing unit tests, integration tests, and end-to-end tests that cover all edge cases while maintaining existing behavior. You must ask the user to confirm before making ANY changes to non-test code, and before implementing any test changes, you always ask the user to confirm your test plan. You focus on: 1) Writing tests that verify functionality without changing existing behavior, 2) Ensuring comprehensive test coverage including edge cases and error conditions, 3) Following testing best practices and patterns appropriate for the language/framework, 4) Using mocks, stubs, and fixtures effectively, 5) Writing clear, maintainable test code with descriptive names and good documentation.", + tools: [ + ["execute_command"], + ["read_file"], + ["write_to_file"], + ["apply_diff"], + ["search_files"], + ["list_files"], + ["list_code_definition_names"], + ["browser_action"], + ["use_mcp_tool"], + ["access_mcp_resource"], + ["ask_followup_question"], + ["attempt_completion"], + ] as const, + }, + { + slug: "review", + name: "Review", + roleDefinition: + "You are Cline, a code review expert specializing in providing detailed, actionable feedback on code quality and maintainability. You excel at: 1) Identifying potential bugs, security vulnerabilities, and performance issues, 2) Ensuring code follows project standards, patterns, and best practices, 3) Checking for proper error handling and edge cases, 4) Verifying documentation completeness and clarity, 5) Suggesting specific, actionable improvements with examples. You maintain a read-only approach to the codebase and focus on helping developers improve their code through clear, constructive feedback.", + tools: [ + ["read_file"], + ["search_files"], + ["list_files"], + ["list_code_definition_names"], + ["browser_action"], + ["use_mcp_tool"], + ["access_mcp_resource"], + ["ask_followup_question"], + ["attempt_completion"], + ] as const, + }, +] as const // Export the default mode slug -export const defaultModeSlug = modes[0].slug; +export const defaultModeSlug = modes[0].slug // Helper functions export function getModeBySlug(slug: string): ModeConfig | undefined { - return modes.find(mode => mode.slug === slug); + return modes.find((mode) => mode.slug === slug) } export function getModeConfig(slug: string): ModeConfig { - const mode = getModeBySlug(slug); - if (!mode) { - throw new Error(`No mode found for slug: ${slug}`); - } - return mode; + const mode = getModeBySlug(slug) + if (!mode) { + throw new Error(`No mode found for slug: ${slug}`) + } + return mode } // Derive tool names from the modes configuration -export type ToolName = typeof modes[number]['tools'][number][0]; -export type TestToolName = ToolName | 'unknown_tool'; +export type ToolName = (typeof modes)[number]["tools"][number][0] +export type TestToolName = ToolName | "unknown_tool" export function isToolAllowedForMode(tool: TestToolName, modeSlug: string): boolean { - if (tool === 'unknown_tool') { - return false; - } - const mode = getModeBySlug(modeSlug); - if (!mode) { - return false; - } - return mode.tools.some(([toolName]) => toolName === tool); + if (tool === "unknown_tool") { + return false + } + const mode = getModeBySlug(modeSlug) + if (!mode) { + return false + } + return mode.tools.some(([toolName]) => toolName === tool) } export function getToolOptions(tool: ToolName, modeSlug: string): ToolOptions | undefined { - const mode = getModeBySlug(modeSlug); - if (!mode) { - return undefined; - } - const toolConfig = mode.tools.find(([toolName]) => toolName === tool); - return toolConfig?.[1]; + const mode = getModeBySlug(modeSlug) + if (!mode) { + return undefined + } + const toolConfig = mode.tools.find(([toolName]) => toolName === tool) + return toolConfig?.[1] } export type PromptComponent = { - roleDefinition?: string; - customInstructions?: string; + roleDefinition?: string + customInstructions?: string } export type CustomPrompts = { - [key: string]: PromptComponent | string | undefined; + [key: string]: PromptComponent | string | undefined } // Create the defaultPrompts object with the correct type export const defaultPrompts: CustomPrompts = { - ...Object.fromEntries(modes.map(mode => [ - mode.slug, - { roleDefinition: mode.roleDefinition } - ])), - enhance: enhance.prompt -} as const; + ...Object.fromEntries(modes.map((mode) => [mode.slug, { roleDefinition: mode.roleDefinition }])), + enhance: enhance.prompt, +} as const // Helper function to safely get role definition export function getRoleDefinition(modeSlug: string): string { - const prompt = defaultPrompts[modeSlug]; - if (!prompt || typeof prompt === 'string') { - throw new Error(`Invalid mode slug: ${modeSlug}`); - } - if (!prompt.roleDefinition) { - throw new Error(`No role definition found for mode: ${modeSlug}`); - } - return prompt.roleDefinition; -} \ No newline at end of file + const prompt = defaultPrompts[modeSlug] + if (!prompt || typeof prompt === "string") { + throw new Error(`Invalid mode slug: ${modeSlug}`) + } + if (!prompt.roleDefinition) { + throw new Error(`No role definition found for mode: ${modeSlug}`) + } + return prompt.roleDefinition +} diff --git a/src/shared/vsCodeSelectorUtils.ts b/src/shared/vsCodeSelectorUtils.ts index a54d63f..620fccc 100644 --- a/src/shared/vsCodeSelectorUtils.ts +++ b/src/shared/vsCodeSelectorUtils.ts @@ -1,14 +1,7 @@ -import { LanguageModelChatSelector } from 'vscode'; +import { LanguageModelChatSelector } from "vscode" -export const SELECTOR_SEPARATOR = '/'; +export const SELECTOR_SEPARATOR = "/" export function stringifyVsCodeLmModelSelector(selector: LanguageModelChatSelector): string { - return [ - selector.vendor, - selector.family, - selector.version, - selector.id - ] - .filter(Boolean) - .join(SELECTOR_SEPARATOR); + return [selector.vendor, selector.family, selector.version, selector.id].filter(Boolean).join(SELECTOR_SEPARATOR) } diff --git a/src/test/extension.test.ts b/src/test/extension.test.ts index 7377f3f..c67b3db 100644 --- a/src/test/extension.test.ts +++ b/src/test/extension.test.ts @@ -1,345 +1,333 @@ -const assert = require('assert'); -const vscode = require('vscode'); -const path = require('path'); -const fs = require('fs'); -const dotenv = require('dotenv'); +const assert = require("assert") +const vscode = require("vscode") +const path = require("path") +const fs = require("fs") +const dotenv = require("dotenv") // Load test environment variables -const testEnvPath = path.join(__dirname, '.test_env'); -dotenv.config({ path: testEnvPath }); +const testEnvPath = path.join(__dirname, ".test_env") +dotenv.config({ path: testEnvPath }) -suite('Roo Cline Extension Test Suite', () => { - vscode.window.showInformationMessage('Starting Roo Cline extension tests.'); +suite("Roo Cline Extension Test Suite", () => { + vscode.window.showInformationMessage("Starting Roo Cline extension tests.") - test('Extension should be present', () => { - const extension = vscode.extensions.getExtension('RooVeterinaryInc.roo-cline'); - assert.notStrictEqual(extension, undefined); - }); + test("Extension should be present", () => { + const extension = vscode.extensions.getExtension("RooVeterinaryInc.roo-cline") + assert.notStrictEqual(extension, undefined) + }) - test('Extension should activate', async () => { - const extension = vscode.extensions.getExtension('RooVeterinaryInc.roo-cline'); + test("Extension should activate", async () => { + const extension = vscode.extensions.getExtension("RooVeterinaryInc.roo-cline") if (!extension) { - assert.fail('Extension not found'); + assert.fail("Extension not found") } - await extension.activate(); - assert.strictEqual(extension.isActive, true); - }); + await extension.activate() + assert.strictEqual(extension.isActive, true) + }) - test('OpenRouter API key and models should be configured correctly', function(done) { + test("OpenRouter API key and models should be configured correctly", function (done) { // @ts-ignore - this.timeout(60000); // Increase timeout to 60s for network requests - - (async () => { + this.timeout(60000) // Increase timeout to 60s for network requests + ;(async () => { try { // Get extension instance - const extension = vscode.extensions.getExtension('RooVeterinaryInc.roo-cline'); + const extension = vscode.extensions.getExtension("RooVeterinaryInc.roo-cline") if (!extension) { - done(new Error('Extension not found')); - return; + done(new Error("Extension not found")) + return } // Verify API key is set and valid - const apiKey = process.env.OPEN_ROUTER_API_KEY; + const apiKey = process.env.OPEN_ROUTER_API_KEY if (!apiKey) { - done(new Error('OPEN_ROUTER_API_KEY environment variable is not set')); - return; + done(new Error("OPEN_ROUTER_API_KEY environment variable is not set")) + return } - if (!apiKey.startsWith('sk-or-v1-')) { - done(new Error('OpenRouter API key should have correct format')); - return; + if (!apiKey.startsWith("sk-or-v1-")) { + done(new Error("OpenRouter API key should have correct format")) + return } // Activate extension and get provider - const api = await extension.activate(); + const api = await extension.activate() if (!api) { - done(new Error('Extension API not found')); - return; + done(new Error("Extension API not found")) + return } // Get the provider from the extension's exports - const provider = api.sidebarProvider; + const provider = api.sidebarProvider if (!provider) { - done(new Error('Provider not found')); - return; + done(new Error("Provider not found")) + return } // Set up the API configuration - await provider.updateGlobalState('apiProvider', 'openrouter'); - await provider.storeSecret('openRouterApiKey', apiKey); + await provider.updateGlobalState("apiProvider", "openrouter") + await provider.storeSecret("openRouterApiKey", apiKey) // Set up timeout to fail test if models don't load const timeout = setTimeout(() => { - done(new Error('Timeout waiting for models to load')); - }, 30000); + done(new Error("Timeout waiting for models to load")) + }, 30000) // Wait for models to be loaded const checkModels = setInterval(async () => { try { - const models = await provider.readOpenRouterModels(); + const models = await provider.readOpenRouterModels() if (!models) { - return; + return } - clearInterval(checkModels); - clearTimeout(timeout); + clearInterval(checkModels) + clearTimeout(timeout) // Verify expected Claude models are available const expectedModels = [ - 'anthropic/claude-3.5-sonnet:beta', - 'anthropic/claude-3-sonnet:beta', - 'anthropic/claude-3.5-sonnet', - 'anthropic/claude-3.5-sonnet-20240620', - 'anthropic/claude-3.5-sonnet-20240620:beta', - 'anthropic/claude-3.5-haiku:beta' - ]; + "anthropic/claude-3.5-sonnet:beta", + "anthropic/claude-3-sonnet:beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-sonnet-20240620", + "anthropic/claude-3.5-sonnet-20240620:beta", + "anthropic/claude-3.5-haiku:beta", + ] for (const modelId of expectedModels) { - assert.strictEqual( - modelId in models, - true, - `Model ${modelId} should be available` - ); + assert.strictEqual(modelId in models, true, `Model ${modelId} should be available`) } - done(); + done() } catch (error) { - clearInterval(checkModels); - clearTimeout(timeout); - done(error); + clearInterval(checkModels) + clearTimeout(timeout) + done(error) } - }, 1000); + }, 1000) // Trigger model loading - await provider.refreshOpenRouterModels(); - + await provider.refreshOpenRouterModels() } catch (error) { - done(error); + done(error) } - })(); - }); + })() + }) + + test("Commands should be registered", async () => { + const commands = await vscode.commands.getCommands(true) - test('Commands should be registered', async () => { - const commands = await vscode.commands.getCommands(true); - // Test core commands are registered const expectedCommands = [ - 'roo-cline.plusButtonClicked', - 'roo-cline.mcpButtonClicked', - 'roo-cline.historyButtonClicked', - 'roo-cline.popoutButtonClicked', - 'roo-cline.settingsButtonClicked', - 'roo-cline.openInNewTab' - ]; + "roo-cline.plusButtonClicked", + "roo-cline.mcpButtonClicked", + "roo-cline.historyButtonClicked", + "roo-cline.popoutButtonClicked", + "roo-cline.settingsButtonClicked", + "roo-cline.openInNewTab", + ] for (const cmd of expectedCommands) { - assert.strictEqual( - commands.includes(cmd), - true, - `Command ${cmd} should be registered` - ); + assert.strictEqual(commands.includes(cmd), true, `Command ${cmd} should be registered`) } - }); + }) - test('Views should be registered', () => { + test("Views should be registered", () => { const view = vscode.window.createWebviewPanel( - 'roo-cline.SidebarProvider', - 'Roo Cline', + "roo-cline.SidebarProvider", + "Roo Cline", vscode.ViewColumn.One, - {} - ); - assert.notStrictEqual(view, undefined); - view.dispose(); - }); + {}, + ) + assert.notStrictEqual(view, undefined) + view.dispose() + }) - test('Should handle prompt and response correctly', async function() { + test("Should handle prompt and response correctly", async function () { // @ts-ignore - this.timeout(60000); // Increase timeout for API request + this.timeout(60000) // Increase timeout for API request - const timeout = 30000; - const interval = 1000; + const timeout = 30000 + const interval = 1000 // Get extension instance - const extension = vscode.extensions.getExtension('RooVeterinaryInc.roo-cline'); + const extension = vscode.extensions.getExtension("RooVeterinaryInc.roo-cline") if (!extension) { - assert.fail('Extension not found'); - return; + assert.fail("Extension not found") + return } // Activate extension and get API - const api = await extension.activate(); + const api = await extension.activate() if (!api) { - assert.fail('Extension API not found'); - return; + assert.fail("Extension API not found") + return } // Get provider - const provider = api.sidebarProvider; + const provider = api.sidebarProvider if (!provider) { - assert.fail('Provider not found'); - return; + assert.fail("Provider not found") + return } // Set up API configuration - await provider.updateGlobalState('apiProvider', 'openrouter'); - await provider.updateGlobalState('openRouterModelId', 'anthropic/claude-3.5-sonnet'); - const apiKey = process.env.OPEN_ROUTER_API_KEY; + await provider.updateGlobalState("apiProvider", "openrouter") + await provider.updateGlobalState("openRouterModelId", "anthropic/claude-3.5-sonnet") + const apiKey = process.env.OPEN_ROUTER_API_KEY if (!apiKey) { - assert.fail('OPEN_ROUTER_API_KEY environment variable is not set'); - return; + assert.fail("OPEN_ROUTER_API_KEY environment variable is not set") + return } - await provider.storeSecret('openRouterApiKey', apiKey); + await provider.storeSecret("openRouterApiKey", apiKey) // Create webview panel with development options - const extensionUri = extension.extensionUri; + const extensionUri = extension.extensionUri const panel = vscode.window.createWebviewPanel( - 'roo-cline.SidebarProvider', - 'Roo Cline', + "roo-cline.SidebarProvider", + "Roo Cline", vscode.ViewColumn.One, { enableScripts: true, enableCommandUris: true, retainContextWhenHidden: true, - localResourceRoots: [extensionUri] - } - ); + localResourceRoots: [extensionUri], + }, + ) try { // Initialize webview with development context panel.webview.options = { enableScripts: true, enableCommandUris: true, - localResourceRoots: [extensionUri] - }; + localResourceRoots: [extensionUri], + } // Initialize provider with panel - provider.resolveWebviewView(panel); + provider.resolveWebviewView(panel) // Set up message tracking - let webviewReady = false; - let messagesReceived = false; - const originalPostMessage = provider.postMessageToWebview.bind(provider); + let webviewReady = false + let messagesReceived = false + const originalPostMessage = provider.postMessageToWebview.bind(provider) // @ts-ignore provider.postMessageToWebview = async (message) => { - if (message.type === 'state') { - webviewReady = true; - console.log('Webview state received:', message); + if (message.type === "state") { + webviewReady = true + console.log("Webview state received:", message) if (message.state?.clineMessages?.length > 0) { - messagesReceived = true; - console.log('Messages in state:', message.state.clineMessages); + messagesReceived = true + console.log("Messages in state:", message.state.clineMessages) } } - await originalPostMessage(message); - }; + await originalPostMessage(message) + } // Wait for webview to launch and receive initial state - let startTime = Date.now(); + let startTime = Date.now() while (Date.now() - startTime < timeout) { if (webviewReady) { // Wait an additional second for webview to fully initialize - await new Promise(resolve => setTimeout(resolve, 1000)); - break; + await new Promise((resolve) => setTimeout(resolve, 1000)) + break } - await new Promise(resolve => setTimeout(resolve, interval)); + await new Promise((resolve) => setTimeout(resolve, interval)) } if (!webviewReady) { - throw new Error('Timeout waiting for webview to be ready'); + throw new Error("Timeout waiting for webview to be ready") } // Send webviewDidLaunch to initialize chat - await provider.postMessageToWebview({ type: 'webviewDidLaunch' }); - console.log('Sent webviewDidLaunch'); + await provider.postMessageToWebview({ type: "webviewDidLaunch" }) + console.log("Sent webviewDidLaunch") // Wait for webview to fully initialize - await new Promise(resolve => setTimeout(resolve, 2000)); + await new Promise((resolve) => setTimeout(resolve, 2000)) // Restore original postMessage - provider.postMessageToWebview = originalPostMessage; + provider.postMessageToWebview = originalPostMessage // Wait for OpenRouter models to be fully loaded - startTime = Date.now(); + startTime = Date.now() while (Date.now() - startTime < timeout) { - const models = await provider.readOpenRouterModels(); + const models = await provider.readOpenRouterModels() if (models && Object.keys(models).length > 0) { - console.log('OpenRouter models loaded'); - break; + console.log("OpenRouter models loaded") + break } - await new Promise(resolve => setTimeout(resolve, interval)); + await new Promise((resolve) => setTimeout(resolve, interval)) } // Send prompt - const prompt = "Hello world, what is your name?"; - console.log('Sending prompt:', prompt); + const prompt = "Hello world, what is your name?" + console.log("Sending prompt:", prompt) // Start task try { - await api.startNewTask(prompt); - console.log('Task started'); + await api.startNewTask(prompt) + console.log("Task started") } catch (error) { - console.error('Error starting task:', error); - throw error; + console.error("Error starting task:", error) + throw error } // Wait for task to appear in history with tokens - startTime = Date.now(); + startTime = Date.now() while (Date.now() - startTime < timeout) { - const state = await provider.getState(); - const task = state.taskHistory?.[0]; + const state = await provider.getState() + const task = state.taskHistory?.[0] if (task && task.tokensOut > 0) { - console.log('Task completed with tokens:', task); - break; + console.log("Task completed with tokens:", task) + break } - await new Promise(resolve => setTimeout(resolve, interval)); + await new Promise((resolve) => setTimeout(resolve, interval)) } // Wait for messages to be processed - startTime = Date.now(); - let responseReceived = false; + startTime = Date.now() + let responseReceived = false while (Date.now() - startTime < timeout) { // Check provider.clineMessages - const messages = provider.clineMessages; + const messages = provider.clineMessages if (messages && messages.length > 0) { - console.log('Provider messages:', JSON.stringify(messages, null, 2)); + console.log("Provider messages:", JSON.stringify(messages, null, 2)) // @ts-ignore - const hasResponse = messages.some(m => - m.type === 'say' && - m.text && - m.text.toLowerCase().includes('cline') - ); + const hasResponse = messages.some( + (m: { type: string; text: string }) => + m.type === "say" && m.text && m.text.toLowerCase().includes("cline"), + ) if (hasResponse) { - console.log('Found response containing "Cline" in provider messages'); - responseReceived = true; - break; + console.log('Found response containing "Cline" in provider messages') + responseReceived = true + break } } // Check provider.cline.clineMessages - const clineMessages = provider.cline?.clineMessages; + const clineMessages = provider.cline?.clineMessages if (clineMessages && clineMessages.length > 0) { - console.log('Cline messages:', JSON.stringify(clineMessages, null, 2)); + console.log("Cline messages:", JSON.stringify(clineMessages, null, 2)) // @ts-ignore - const hasResponse = clineMessages.some(m => - m.type === 'say' && - m.text && - m.text.toLowerCase().includes('cline') - ); + const hasResponse = clineMessages.some( + (m: { type: string; text: string }) => + m.type === "say" && m.text && m.text.toLowerCase().includes("cline"), + ) if (hasResponse) { - console.log('Found response containing "Cline" in cline messages'); - responseReceived = true; - break; + console.log('Found response containing "Cline" in cline messages') + responseReceived = true + break } } - await new Promise(resolve => setTimeout(resolve, interval)); + await new Promise((resolve) => setTimeout(resolve, interval)) } if (!responseReceived) { - console.log('Final provider state:', await provider.getState()); - console.log('Final cline messages:', provider.cline?.clineMessages); - throw new Error('Did not receive expected response containing "Cline"'); + console.log("Final provider state:", await provider.getState()) + console.log("Final cline messages:", provider.cline?.clineMessages) + throw new Error('Did not receive expected response containing "Cline"') } } finally { - panel.dispose(); + panel.dispose() } - }); -}); + }) +}) diff --git a/src/test/tsconfig.json b/src/test/tsconfig.json index 0560c90..5c488e9 100644 --- a/src/test/tsconfig.json +++ b/src/test/tsconfig.json @@ -1,19 +1,19 @@ { - "compilerOptions": { - "module": "commonjs", - "target": "ES2020", - "lib": ["ES2020"], - "sourceMap": true, - "rootDir": "../..", - "strict": false, - "noImplicitAny": false, - "noImplicitThis": false, - "alwaysStrict": false, - "skipLibCheck": true, - "baseUrl": "../..", - "paths": { - "*": ["*", "src/*"] - } - }, - "exclude": ["node_modules", ".vscode-test"] -} \ No newline at end of file + "compilerOptions": { + "module": "commonjs", + "target": "ES2020", + "lib": ["ES2020"], + "sourceMap": true, + "rootDir": "../..", + "strict": false, + "noImplicitAny": false, + "noImplicitThis": false, + "alwaysStrict": false, + "skipLibCheck": true, + "baseUrl": "../..", + "paths": { + "*": ["*", "src/*"] + } + }, + "exclude": ["node_modules", ".vscode-test"] +} diff --git a/src/utils/__tests__/cost.test.ts b/src/utils/__tests__/cost.test.ts index b1a44aa..e390c4a 100644 --- a/src/utils/__tests__/cost.test.ts +++ b/src/utils/__tests__/cost.test.ts @@ -1,97 +1,97 @@ -import { calculateApiCost } from '../cost'; -import { ModelInfo } from '../../shared/api'; +import { calculateApiCost } from "../cost" +import { ModelInfo } from "../../shared/api" -describe('Cost Utility', () => { - describe('calculateApiCost', () => { - const mockModelInfo: ModelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: true, - inputPrice: 3.0, // $3 per million tokens - outputPrice: 15.0, // $15 per million tokens - cacheWritesPrice: 3.75, // $3.75 per million tokens - cacheReadsPrice: 0.3, // $0.30 per million tokens - }; +describe("Cost Utility", () => { + describe("calculateApiCost", () => { + const mockModelInfo: ModelInfo = { + maxTokens: 8192, + contextWindow: 200_000, + supportsPromptCache: true, + inputPrice: 3.0, // $3 per million tokens + outputPrice: 15.0, // $15 per million tokens + cacheWritesPrice: 3.75, // $3.75 per million tokens + cacheReadsPrice: 0.3, // $0.30 per million tokens + } - it('should calculate basic input/output costs correctly', () => { - const cost = calculateApiCost(mockModelInfo, 1000, 500); - - // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 - // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 - // Total: 0.003 + 0.0075 = 0.0105 - expect(cost).toBe(0.0105); - }); + it("should calculate basic input/output costs correctly", () => { + const cost = calculateApiCost(mockModelInfo, 1000, 500) - it('should handle cache writes cost', () => { - const cost = calculateApiCost(mockModelInfo, 1000, 500, 2000); - - // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 - // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 - // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 - // Total: 0.003 + 0.0075 + 0.0075 = 0.018 - expect(cost).toBeCloseTo(0.018, 6); - }); + // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 + // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 + // Total: 0.003 + 0.0075 = 0.0105 + expect(cost).toBe(0.0105) + }) - it('should handle cache reads cost', () => { - const cost = calculateApiCost(mockModelInfo, 1000, 500, undefined, 3000); - - // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 - // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 - // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 - // Total: 0.003 + 0.0075 + 0.0009 = 0.0114 - expect(cost).toBe(0.0114); - }); + it("should handle cache writes cost", () => { + const cost = calculateApiCost(mockModelInfo, 1000, 500, 2000) - it('should handle all cost components together', () => { - const cost = calculateApiCost(mockModelInfo, 1000, 500, 2000, 3000); - - // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 - // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 - // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 - // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 - // Total: 0.003 + 0.0075 + 0.0075 + 0.0009 = 0.0189 - expect(cost).toBe(0.0189); - }); + // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 + // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 + // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 + // Total: 0.003 + 0.0075 + 0.0075 = 0.018 + expect(cost).toBeCloseTo(0.018, 6) + }) - it('should handle missing prices gracefully', () => { - const modelWithoutPrices: ModelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: true - }; + it("should handle cache reads cost", () => { + const cost = calculateApiCost(mockModelInfo, 1000, 500, undefined, 3000) - const cost = calculateApiCost(modelWithoutPrices, 1000, 500, 2000, 3000); - expect(cost).toBe(0); - }); + // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 + // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 + // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 + // Total: 0.003 + 0.0075 + 0.0009 = 0.0114 + expect(cost).toBe(0.0114) + }) - it('should handle zero tokens', () => { - const cost = calculateApiCost(mockModelInfo, 0, 0, 0, 0); - expect(cost).toBe(0); - }); + it("should handle all cost components together", () => { + const cost = calculateApiCost(mockModelInfo, 1000, 500, 2000, 3000) - it('should handle undefined cache values', () => { - const cost = calculateApiCost(mockModelInfo, 1000, 500); - - // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 - // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 - // Total: 0.003 + 0.0075 = 0.0105 - expect(cost).toBe(0.0105); - }); + // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 + // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 + // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 + // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 + // Total: 0.003 + 0.0075 + 0.0075 + 0.0009 = 0.0189 + expect(cost).toBe(0.0189) + }) - it('should handle missing cache prices', () => { - const modelWithoutCachePrices: ModelInfo = { - ...mockModelInfo, - cacheWritesPrice: undefined, - cacheReadsPrice: undefined - }; + it("should handle missing prices gracefully", () => { + const modelWithoutPrices: ModelInfo = { + maxTokens: 8192, + contextWindow: 200_000, + supportsPromptCache: true, + } - const cost = calculateApiCost(modelWithoutCachePrices, 1000, 500, 2000, 3000); - - // Should only include input and output costs - // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 - // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 - // Total: 0.003 + 0.0075 = 0.0105 - expect(cost).toBe(0.0105); - }); - }); -}); \ No newline at end of file + const cost = calculateApiCost(modelWithoutPrices, 1000, 500, 2000, 3000) + expect(cost).toBe(0) + }) + + it("should handle zero tokens", () => { + const cost = calculateApiCost(mockModelInfo, 0, 0, 0, 0) + expect(cost).toBe(0) + }) + + it("should handle undefined cache values", () => { + const cost = calculateApiCost(mockModelInfo, 1000, 500) + + // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 + // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 + // Total: 0.003 + 0.0075 = 0.0105 + expect(cost).toBe(0.0105) + }) + + it("should handle missing cache prices", () => { + const modelWithoutCachePrices: ModelInfo = { + ...mockModelInfo, + cacheWritesPrice: undefined, + cacheReadsPrice: undefined, + } + + const cost = calculateApiCost(modelWithoutCachePrices, 1000, 500, 2000, 3000) + + // Should only include input and output costs + // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 + // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 + // Total: 0.003 + 0.0075 = 0.0105 + expect(cost).toBe(0.0105) + }) + }) +}) diff --git a/src/utils/__tests__/enhance-prompt.test.ts b/src/utils/__tests__/enhance-prompt.test.ts index 61b89c1..69fdd04 100644 --- a/src/utils/__tests__/enhance-prompt.test.ts +++ b/src/utils/__tests__/enhance-prompt.test.ts @@ -1,126 +1,126 @@ -import { enhancePrompt } from '../enhance-prompt' -import { ApiConfiguration } from '../../shared/api' -import { buildApiHandler, SingleCompletionHandler } from '../../api' -import { defaultPrompts } from '../../shared/modes' +import { enhancePrompt } from "../enhance-prompt" +import { ApiConfiguration } from "../../shared/api" +import { buildApiHandler, SingleCompletionHandler } from "../../api" +import { defaultPrompts } from "../../shared/modes" // Mock the API handler -jest.mock('../../api', () => ({ - buildApiHandler: jest.fn() +jest.mock("../../api", () => ({ + buildApiHandler: jest.fn(), })) -describe('enhancePrompt', () => { - const mockApiConfig: ApiConfiguration = { - apiProvider: 'openai', - openAiApiKey: 'test-key', - openAiBaseUrl: 'https://api.openai.com/v1' - } +describe("enhancePrompt", () => { + const mockApiConfig: ApiConfiguration = { + apiProvider: "openai", + openAiApiKey: "test-key", + openAiBaseUrl: "https://api.openai.com/v1", + } - beforeEach(() => { - jest.clearAllMocks() - - // Mock the API handler with a completePrompt method - ;(buildApiHandler as jest.Mock).mockReturnValue({ - completePrompt: jest.fn().mockResolvedValue('Enhanced prompt'), - createMessage: jest.fn(), - getModel: jest.fn().mockReturnValue({ - id: 'test-model', - info: { - maxTokens: 4096, - contextWindow: 8192, - supportsPromptCache: false - } - }) - } as unknown as SingleCompletionHandler) - }) + beforeEach(() => { + jest.clearAllMocks() - it('enhances prompt using default enhancement prompt when no custom prompt provided', async () => { - const result = await enhancePrompt(mockApiConfig, 'Test prompt') - - expect(result).toBe('Enhanced prompt') - const handler = buildApiHandler(mockApiConfig) - expect((handler as any).completePrompt).toHaveBeenCalledWith( - `${defaultPrompts.enhance}\n\nTest prompt` - ) - }) + // Mock the API handler with a completePrompt method + ;(buildApiHandler as jest.Mock).mockReturnValue({ + completePrompt: jest.fn().mockResolvedValue("Enhanced prompt"), + createMessage: jest.fn(), + getModel: jest.fn().mockReturnValue({ + id: "test-model", + info: { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + }, + }), + } as unknown as SingleCompletionHandler) + }) - it('enhances prompt using custom enhancement prompt when provided', async () => { - const customEnhancePrompt = 'You are a custom prompt enhancer' - - const result = await enhancePrompt(mockApiConfig, 'Test prompt', customEnhancePrompt) - - expect(result).toBe('Enhanced prompt') - const handler = buildApiHandler(mockApiConfig) - expect((handler as any).completePrompt).toHaveBeenCalledWith( - `${customEnhancePrompt}\n\nTest prompt` - ) - }) + it("enhances prompt using default enhancement prompt when no custom prompt provided", async () => { + const result = await enhancePrompt(mockApiConfig, "Test prompt") - it('throws error for empty prompt input', async () => { - await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided') - }) + expect(result).toBe("Enhanced prompt") + const handler = buildApiHandler(mockApiConfig) + expect((handler as any).completePrompt).toHaveBeenCalledWith(`${defaultPrompts.enhance}\n\nTest prompt`) + }) - it('throws error for missing API configuration', async () => { - await expect(enhancePrompt({} as ApiConfiguration, 'Test prompt')).rejects.toThrow('No valid API configuration provided') - }) + it("enhances prompt using custom enhancement prompt when provided", async () => { + const customEnhancePrompt = "You are a custom prompt enhancer" - it('throws error for API provider that does not support prompt enhancement', async () => { - (buildApiHandler as jest.Mock).mockReturnValue({ - // No completePrompt method - createMessage: jest.fn(), - getModel: jest.fn().mockReturnValue({ - id: 'test-model', - info: { - maxTokens: 4096, - contextWindow: 8192, - supportsPromptCache: false - } - }) - }) + const result = await enhancePrompt(mockApiConfig, "Test prompt", customEnhancePrompt) - await expect(enhancePrompt(mockApiConfig, 'Test prompt')).rejects.toThrow('The selected API provider does not support prompt enhancement') - }) + expect(result).toBe("Enhanced prompt") + const handler = buildApiHandler(mockApiConfig) + expect((handler as any).completePrompt).toHaveBeenCalledWith(`${customEnhancePrompt}\n\nTest prompt`) + }) - it('uses appropriate model based on provider', async () => { - const openRouterConfig: ApiConfiguration = { - apiProvider: 'openrouter', - openRouterApiKey: 'test-key', - openRouterModelId: 'test-model' - } + it("throws error for empty prompt input", async () => { + await expect(enhancePrompt(mockApiConfig, "")).rejects.toThrow("No prompt text provided") + }) - // Mock successful enhancement - ;(buildApiHandler as jest.Mock).mockReturnValue({ - completePrompt: jest.fn().mockResolvedValue('Enhanced prompt'), - createMessage: jest.fn(), - getModel: jest.fn().mockReturnValue({ - id: 'test-model', - info: { - maxTokens: 4096, - contextWindow: 8192, - supportsPromptCache: false - } - }) - } as unknown as SingleCompletionHandler) + it("throws error for missing API configuration", async () => { + await expect(enhancePrompt({} as ApiConfiguration, "Test prompt")).rejects.toThrow( + "No valid API configuration provided", + ) + }) - const result = await enhancePrompt(openRouterConfig, 'Test prompt') - - expect(buildApiHandler).toHaveBeenCalledWith(openRouterConfig) - expect(result).toBe('Enhanced prompt') - }) + it("throws error for API provider that does not support prompt enhancement", async () => { + ;(buildApiHandler as jest.Mock).mockReturnValue({ + // No completePrompt method + createMessage: jest.fn(), + getModel: jest.fn().mockReturnValue({ + id: "test-model", + info: { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + }, + }), + }) - it('propagates API errors', async () => { - (buildApiHandler as jest.Mock).mockReturnValue({ - completePrompt: jest.fn().mockRejectedValue(new Error('API Error')), - createMessage: jest.fn(), - getModel: jest.fn().mockReturnValue({ - id: 'test-model', - info: { - maxTokens: 4096, - contextWindow: 8192, - supportsPromptCache: false - } - }) - } as unknown as SingleCompletionHandler) + await expect(enhancePrompt(mockApiConfig, "Test prompt")).rejects.toThrow( + "The selected API provider does not support prompt enhancement", + ) + }) - await expect(enhancePrompt(mockApiConfig, 'Test prompt')).rejects.toThrow('API Error') - }) -}) \ No newline at end of file + it("uses appropriate model based on provider", async () => { + const openRouterConfig: ApiConfiguration = { + apiProvider: "openrouter", + openRouterApiKey: "test-key", + openRouterModelId: "test-model", + } + + // Mock successful enhancement + ;(buildApiHandler as jest.Mock).mockReturnValue({ + completePrompt: jest.fn().mockResolvedValue("Enhanced prompt"), + createMessage: jest.fn(), + getModel: jest.fn().mockReturnValue({ + id: "test-model", + info: { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + }, + }), + } as unknown as SingleCompletionHandler) + + const result = await enhancePrompt(openRouterConfig, "Test prompt") + + expect(buildApiHandler).toHaveBeenCalledWith(openRouterConfig) + expect(result).toBe("Enhanced prompt") + }) + + it("propagates API errors", async () => { + ;(buildApiHandler as jest.Mock).mockReturnValue({ + completePrompt: jest.fn().mockRejectedValue(new Error("API Error")), + createMessage: jest.fn(), + getModel: jest.fn().mockReturnValue({ + id: "test-model", + info: { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + }, + }), + } as unknown as SingleCompletionHandler) + + await expect(enhancePrompt(mockApiConfig, "Test prompt")).rejects.toThrow("API Error") + }) +}) diff --git a/src/utils/__tests__/git.test.ts b/src/utils/__tests__/git.test.ts index 6e2355d..7fe59a1 100644 --- a/src/utils/__tests__/git.test.ts +++ b/src/utils/__tests__/git.test.ts @@ -1,336 +1,344 @@ -import { jest } from '@jest/globals' -import { searchCommits, getCommitInfo, getWorkingState, GitCommit } from '../git' -import { ExecException } from 'child_process' +import { jest } from "@jest/globals" +import { searchCommits, getCommitInfo, getWorkingState, GitCommit } from "../git" +import { ExecException } from "child_process" type ExecFunction = ( - command: string, - options: { cwd?: string }, - callback: (error: ExecException | null, result?: { stdout: string; stderr: string }) => void + command: string, + options: { cwd?: string }, + callback: (error: ExecException | null, result?: { stdout: string; stderr: string }) => void, ) => void type PromisifiedExec = (command: string, options?: { cwd?: string }) => Promise<{ stdout: string; stderr: string }> // Mock child_process.exec -jest.mock('child_process', () => ({ - exec: jest.fn() +jest.mock("child_process", () => ({ + exec: jest.fn(), })) // Mock util.promisify to return our own mock function -jest.mock('util', () => ({ - promisify: jest.fn((fn: ExecFunction): PromisifiedExec => { - return async (command: string, options?: { cwd?: string }) => { - // Call the original mock to maintain the mock implementation - return new Promise((resolve, reject) => { - fn(command, options || {}, (error: ExecException | null, result?: { stdout: string; stderr: string }) => { - if (error) { - reject(error) - } else { - resolve(result!) - } - }) - }) - } - }) +jest.mock("util", () => ({ + promisify: jest.fn((fn: ExecFunction): PromisifiedExec => { + return async (command: string, options?: { cwd?: string }) => { + // Call the original mock to maintain the mock implementation + return new Promise((resolve, reject) => { + fn( + command, + options || {}, + (error: ExecException | null, result?: { stdout: string; stderr: string }) => { + if (error) { + reject(error) + } else { + resolve(result!) + } + }, + ) + }) + } + }), })) // Mock extract-text -jest.mock('../../integrations/misc/extract-text', () => ({ - truncateOutput: jest.fn(text => text) +jest.mock("../../integrations/misc/extract-text", () => ({ + truncateOutput: jest.fn((text) => text), })) -describe('git utils', () => { - // Get the mock with proper typing - const { exec } = jest.requireMock('child_process') as { exec: jest.MockedFunction } - const cwd = '/test/path' +describe("git utils", () => { + // Get the mock with proper typing + const { exec } = jest.requireMock("child_process") as { exec: jest.MockedFunction } + const cwd = "/test/path" - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => { + jest.clearAllMocks() + }) - describe('searchCommits', () => { - const mockCommitData = [ - 'abc123def456', - 'abc123', - 'fix: test commit', - 'John Doe', - '2024-01-06', - 'def456abc789', - 'def456', - 'feat: new feature', - 'Jane Smith', - '2024-01-05' - ].join('\n') + describe("searchCommits", () => { + const mockCommitData = [ + "abc123def456", + "abc123", + "fix: test commit", + "John Doe", + "2024-01-06", + "def456abc789", + "def456", + "feat: new feature", + "Jane Smith", + "2024-01-05", + ].join("\n") - it('should return commits when git is installed and repo exists', async () => { - // Set up mock responses - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', { stdout: '.git', stderr: '' }], - ['git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --grep="test" --regexp-ignore-case', { stdout: mockCommitData, stderr: '' }] - ]) + it("should return commits when git is installed and repo exists", async () => { + // Set up mock responses + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", { stdout: ".git", stderr: "" }], + [ + 'git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --grep="test" --regexp-ignore-case', + { stdout: mockCommitData, stderr: "" }, + ], + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - // Find matching response - for (const [cmd, response] of responses) { - if (command === cmd) { - callback(null, response) - return - } - } - callback(new Error(`Unexpected command: ${command}`)) - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + // Find matching response + for (const [cmd, response] of responses) { + if (command === cmd) { + callback(null, response) + return + } + } + callback(new Error(`Unexpected command: ${command}`)) + }) - const result = await searchCommits('test', cwd) + const result = await searchCommits("test", cwd) - // First verify the result is correct - expect(result).toHaveLength(2) - expect(result[0]).toEqual({ - hash: 'abc123def456', - shortHash: 'abc123', - subject: 'fix: test commit', - author: 'John Doe', - date: '2024-01-06' - }) + // First verify the result is correct + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + hash: "abc123def456", + shortHash: "abc123", + subject: "fix: test commit", + author: "John Doe", + date: "2024-01-06", + }) - // Then verify all commands were called correctly - expect(exec).toHaveBeenCalledWith( - 'git --version', - {}, - expect.any(Function) - ) - expect(exec).toHaveBeenCalledWith( - 'git rev-parse --git-dir', - { cwd }, - expect.any(Function) - ) - expect(exec).toHaveBeenCalledWith( - 'git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --grep="test" --regexp-ignore-case', - { cwd }, - expect.any(Function) - ) - }, 20000) + // Then verify all commands were called correctly + expect(exec).toHaveBeenCalledWith("git --version", {}, expect.any(Function)) + expect(exec).toHaveBeenCalledWith("git rev-parse --git-dir", { cwd }, expect.any(Function)) + expect(exec).toHaveBeenCalledWith( + 'git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --grep="test" --regexp-ignore-case', + { cwd }, + expect.any(Function), + ) + }, 20000) - it('should return empty array when git is not installed', async () => { - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - if (command === 'git --version') { - callback(new Error('git not found')) - return - } - callback(new Error('Unexpected command')) - }) + it("should return empty array when git is not installed", async () => { + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + if (command === "git --version") { + callback(new Error("git not found")) + return + } + callback(new Error("Unexpected command")) + }) - const result = await searchCommits('test', cwd) - expect(result).toEqual([]) - expect(exec).toHaveBeenCalledWith('git --version', {}, expect.any(Function)) - }) + const result = await searchCommits("test", cwd) + expect(result).toEqual([]) + expect(exec).toHaveBeenCalledWith("git --version", {}, expect.any(Function)) + }) - it('should return empty array when not in a git repository', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', null] // null indicates error should be called - ]) + it("should return empty array when not in a git repository", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", null], // null indicates error should be called + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - const response = responses.get(command) - if (response === null) { - callback(new Error('not a git repository')) - } else if (response) { - callback(null, response) - } else { - callback(new Error('Unexpected command')) - } - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + const response = responses.get(command) + if (response === null) { + callback(new Error("not a git repository")) + } else if (response) { + callback(null, response) + } else { + callback(new Error("Unexpected command")) + } + }) - const result = await searchCommits('test', cwd) - expect(result).toEqual([]) - expect(exec).toHaveBeenCalledWith('git --version', {}, expect.any(Function)) - expect(exec).toHaveBeenCalledWith('git rev-parse --git-dir', { cwd }, expect.any(Function)) - }) + const result = await searchCommits("test", cwd) + expect(result).toEqual([]) + expect(exec).toHaveBeenCalledWith("git --version", {}, expect.any(Function)) + expect(exec).toHaveBeenCalledWith("git rev-parse --git-dir", { cwd }, expect.any(Function)) + }) - it('should handle hash search when grep search returns no results', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', { stdout: '.git', stderr: '' }], - ['git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --grep="abc123" --regexp-ignore-case', { stdout: '', stderr: '' }], - ['git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --author-date-order abc123', { stdout: mockCommitData, stderr: '' }] - ]) + it("should handle hash search when grep search returns no results", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", { stdout: ".git", stderr: "" }], + [ + 'git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --grep="abc123" --regexp-ignore-case', + { stdout: "", stderr: "" }, + ], + [ + 'git log -n 10 --format="%H%n%h%n%s%n%an%n%ad" --date=short --author-date-order abc123', + { stdout: mockCommitData, stderr: "" }, + ], + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - for (const [cmd, response] of responses) { - if (command === cmd) { - callback(null, response) - return - } - } - callback(new Error('Unexpected command')) - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + for (const [cmd, response] of responses) { + if (command === cmd) { + callback(null, response) + return + } + } + callback(new Error("Unexpected command")) + }) - const result = await searchCommits('abc123', cwd) - expect(result).toHaveLength(2) - expect(result[0]).toEqual({ - hash: 'abc123def456', - shortHash: 'abc123', - subject: 'fix: test commit', - author: 'John Doe', - date: '2024-01-06' - }) - }) - }) + const result = await searchCommits("abc123", cwd) + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + hash: "abc123def456", + shortHash: "abc123", + subject: "fix: test commit", + author: "John Doe", + date: "2024-01-06", + }) + }) + }) - describe('getCommitInfo', () => { - const mockCommitInfo = [ - 'abc123def456', - 'abc123', - 'fix: test commit', - 'John Doe', - '2024-01-06', - 'Detailed description' - ].join('\n') - const mockStats = '1 file changed, 2 insertions(+), 1 deletion(-)' - const mockDiff = '@@ -1,1 +1,2 @@\n-old line\n+new line' + describe("getCommitInfo", () => { + const mockCommitInfo = [ + "abc123def456", + "abc123", + "fix: test commit", + "John Doe", + "2024-01-06", + "Detailed description", + ].join("\n") + const mockStats = "1 file changed, 2 insertions(+), 1 deletion(-)" + const mockDiff = "@@ -1,1 +1,2 @@\n-old line\n+new line" - it('should return formatted commit info', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', { stdout: '.git', stderr: '' }], - ['git show --format="%H%n%h%n%s%n%an%n%ad%n%b" --no-patch abc123', { stdout: mockCommitInfo, stderr: '' }], - ['git show --stat --format="" abc123', { stdout: mockStats, stderr: '' }], - ['git show --format="" abc123', { stdout: mockDiff, stderr: '' }] - ]) + it("should return formatted commit info", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", { stdout: ".git", stderr: "" }], + [ + 'git show --format="%H%n%h%n%s%n%an%n%ad%n%b" --no-patch abc123', + { stdout: mockCommitInfo, stderr: "" }, + ], + ['git show --stat --format="" abc123', { stdout: mockStats, stderr: "" }], + ['git show --format="" abc123', { stdout: mockDiff, stderr: "" }], + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - for (const [cmd, response] of responses) { - if (command.startsWith(cmd)) { - callback(null, response) - return - } - } - callback(new Error('Unexpected command')) - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + for (const [cmd, response] of responses) { + if (command.startsWith(cmd)) { + callback(null, response) + return + } + } + callback(new Error("Unexpected command")) + }) - const result = await getCommitInfo('abc123', cwd) - expect(result).toContain('Commit: abc123') - expect(result).toContain('Author: John Doe') - expect(result).toContain('Files Changed:') - expect(result).toContain('Full Changes:') - }) + const result = await getCommitInfo("abc123", cwd) + expect(result).toContain("Commit: abc123") + expect(result).toContain("Author: John Doe") + expect(result).toContain("Files Changed:") + expect(result).toContain("Full Changes:") + }) - it('should return error message when git is not installed', async () => { - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - if (command === 'git --version') { - callback(new Error('git not found')) - return - } - callback(new Error('Unexpected command')) - }) + it("should return error message when git is not installed", async () => { + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + if (command === "git --version") { + callback(new Error("git not found")) + return + } + callback(new Error("Unexpected command")) + }) - const result = await getCommitInfo('abc123', cwd) - expect(result).toBe('Git is not installed') - }) + const result = await getCommitInfo("abc123", cwd) + expect(result).toBe("Git is not installed") + }) - it('should return error message when not in a git repository', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', null] // null indicates error should be called - ]) + it("should return error message when not in a git repository", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", null], // null indicates error should be called + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - const response = responses.get(command) - if (response === null) { - callback(new Error('not a git repository')) - } else if (response) { - callback(null, response) - } else { - callback(new Error('Unexpected command')) - } - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + const response = responses.get(command) + if (response === null) { + callback(new Error("not a git repository")) + } else if (response) { + callback(null, response) + } else { + callback(new Error("Unexpected command")) + } + }) - const result = await getCommitInfo('abc123', cwd) - expect(result).toBe('Not a git repository') - }) - }) + const result = await getCommitInfo("abc123", cwd) + expect(result).toBe("Not a git repository") + }) + }) - describe('getWorkingState', () => { - const mockStatus = ' M src/file1.ts\n?? src/file2.ts' - const mockDiff = '@@ -1,1 +1,2 @@\n-old line\n+new line' + describe("getWorkingState", () => { + const mockStatus = " M src/file1.ts\n?? src/file2.ts" + const mockDiff = "@@ -1,1 +1,2 @@\n-old line\n+new line" - it('should return working directory changes', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', { stdout: '.git', stderr: '' }], - ['git status --short', { stdout: mockStatus, stderr: '' }], - ['git diff HEAD', { stdout: mockDiff, stderr: '' }] - ]) + it("should return working directory changes", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", { stdout: ".git", stderr: "" }], + ["git status --short", { stdout: mockStatus, stderr: "" }], + ["git diff HEAD", { stdout: mockDiff, stderr: "" }], + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - for (const [cmd, response] of responses) { - if (command === cmd) { - callback(null, response) - return - } - } - callback(new Error('Unexpected command')) - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + for (const [cmd, response] of responses) { + if (command === cmd) { + callback(null, response) + return + } + } + callback(new Error("Unexpected command")) + }) - const result = await getWorkingState(cwd) - expect(result).toContain('Working directory changes:') - expect(result).toContain('src/file1.ts') - expect(result).toContain('src/file2.ts') - }) + const result = await getWorkingState(cwd) + expect(result).toContain("Working directory changes:") + expect(result).toContain("src/file1.ts") + expect(result).toContain("src/file2.ts") + }) - it('should return message when working directory is clean', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', { stdout: '.git', stderr: '' }], - ['git status --short', { stdout: '', stderr: '' }] - ]) + it("should return message when working directory is clean", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", { stdout: ".git", stderr: "" }], + ["git status --short", { stdout: "", stderr: "" }], + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - for (const [cmd, response] of responses) { - if (command === cmd) { - callback(null, response) - return - } - } - callback(new Error('Unexpected command')) - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + for (const [cmd, response] of responses) { + if (command === cmd) { + callback(null, response) + return + } + } + callback(new Error("Unexpected command")) + }) - const result = await getWorkingState(cwd) - expect(result).toBe('No changes in working directory') - }) + const result = await getWorkingState(cwd) + expect(result).toBe("No changes in working directory") + }) - it('should return error message when git is not installed', async () => { - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - if (command === 'git --version') { - callback(new Error('git not found')) - return - } - callback(new Error('Unexpected command')) - }) + it("should return error message when git is not installed", async () => { + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + if (command === "git --version") { + callback(new Error("git not found")) + return + } + callback(new Error("Unexpected command")) + }) - const result = await getWorkingState(cwd) - expect(result).toBe('Git is not installed') - }) + const result = await getWorkingState(cwd) + expect(result).toBe("Git is not installed") + }) - it('should return error message when not in a git repository', async () => { - const responses = new Map([ - ['git --version', { stdout: 'git version 2.39.2', stderr: '' }], - ['git rev-parse --git-dir', null] // null indicates error should be called - ]) + it("should return error message when not in a git repository", async () => { + const responses = new Map([ + ["git --version", { stdout: "git version 2.39.2", stderr: "" }], + ["git rev-parse --git-dir", null], // null indicates error should be called + ]) - exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { - const response = responses.get(command) - if (response === null) { - callback(new Error('not a git repository')) - } else if (response) { - callback(null, response) - } else { - callback(new Error('Unexpected command')) - } - }) + exec.mockImplementation((command: string, options: { cwd?: string }, callback: Function) => { + const response = responses.get(command) + if (response === null) { + callback(new Error("not a git repository")) + } else if (response) { + callback(null, response) + } else { + callback(new Error("Unexpected command")) + } + }) - const result = await getWorkingState(cwd) - expect(result).toBe('Not a git repository') - }) - }) -}) \ No newline at end of file + const result = await getWorkingState(cwd) + expect(result).toBe("Not a git repository") + }) + }) +}) diff --git a/src/utils/__tests__/path.test.ts b/src/utils/__tests__/path.test.ts index 5b35c05..1d20e86 100644 --- a/src/utils/__tests__/path.test.ts +++ b/src/utils/__tests__/path.test.ts @@ -1,135 +1,135 @@ -import { arePathsEqual, getReadablePath } from '../path'; -import * as path from 'path'; -import os from 'os'; +import { arePathsEqual, getReadablePath } from "../path" +import * as path from "path" +import os from "os" -describe('Path Utilities', () => { - const originalPlatform = process.platform; +describe("Path Utilities", () => { + const originalPlatform = process.platform - afterEach(() => { - Object.defineProperty(process, 'platform', { - value: originalPlatform - }); - }); + afterEach(() => { + Object.defineProperty(process, "platform", { + value: originalPlatform, + }) + }) - describe('String.prototype.toPosix', () => { - it('should convert backslashes to forward slashes', () => { - const windowsPath = 'C:\\Users\\test\\file.txt'; - expect(windowsPath.toPosix()).toBe('C:/Users/test/file.txt'); - }); + describe("String.prototype.toPosix", () => { + it("should convert backslashes to forward slashes", () => { + const windowsPath = "C:\\Users\\test\\file.txt" + expect(windowsPath.toPosix()).toBe("C:/Users/test/file.txt") + }) - it('should not modify paths with forward slashes', () => { - const unixPath = '/home/user/file.txt'; - expect(unixPath.toPosix()).toBe('/home/user/file.txt'); - }); + it("should not modify paths with forward slashes", () => { + const unixPath = "/home/user/file.txt" + expect(unixPath.toPosix()).toBe("/home/user/file.txt") + }) - it('should preserve extended-length Windows paths', () => { - const extendedPath = '\\\\?\\C:\\Very\\Long\\Path'; - expect(extendedPath.toPosix()).toBe('\\\\?\\C:\\Very\\Long\\Path'); - }); - }); + it("should preserve extended-length Windows paths", () => { + const extendedPath = "\\\\?\\C:\\Very\\Long\\Path" + expect(extendedPath.toPosix()).toBe("\\\\?\\C:\\Very\\Long\\Path") + }) + }) - describe('arePathsEqual', () => { - describe('on Windows', () => { - beforeEach(() => { - Object.defineProperty(process, 'platform', { - value: 'win32' - }); - }); + describe("arePathsEqual", () => { + describe("on Windows", () => { + beforeEach(() => { + Object.defineProperty(process, "platform", { + value: "win32", + }) + }) - it('should compare paths case-insensitively', () => { - expect(arePathsEqual('C:\\Users\\Test', 'c:\\users\\test')).toBe(true); - }); + it("should compare paths case-insensitively", () => { + expect(arePathsEqual("C:\\Users\\Test", "c:\\users\\test")).toBe(true) + }) - it('should handle different path separators', () => { - // Convert both paths to use forward slashes after normalization - const path1 = path.normalize('C:\\Users\\Test').replace(/\\/g, '/'); - const path2 = path.normalize('C:/Users/Test').replace(/\\/g, '/'); - expect(arePathsEqual(path1, path2)).toBe(true); - }); + it("should handle different path separators", () => { + // Convert both paths to use forward slashes after normalization + const path1 = path.normalize("C:\\Users\\Test").replace(/\\/g, "/") + const path2 = path.normalize("C:/Users/Test").replace(/\\/g, "/") + expect(arePathsEqual(path1, path2)).toBe(true) + }) - it('should normalize paths with ../', () => { - // Convert both paths to use forward slashes after normalization - const path1 = path.normalize('C:\\Users\\Test\\..\\Test').replace(/\\/g, '/'); - const path2 = path.normalize('C:\\Users\\Test').replace(/\\/g, '/'); - expect(arePathsEqual(path1, path2)).toBe(true); - }); - }); + it("should normalize paths with ../", () => { + // Convert both paths to use forward slashes after normalization + const path1 = path.normalize("C:\\Users\\Test\\..\\Test").replace(/\\/g, "/") + const path2 = path.normalize("C:\\Users\\Test").replace(/\\/g, "/") + expect(arePathsEqual(path1, path2)).toBe(true) + }) + }) - describe('on POSIX', () => { - beforeEach(() => { - Object.defineProperty(process, 'platform', { - value: 'darwin' - }); - }); + describe("on POSIX", () => { + beforeEach(() => { + Object.defineProperty(process, "platform", { + value: "darwin", + }) + }) - it('should compare paths case-sensitively', () => { - expect(arePathsEqual('/Users/Test', '/Users/test')).toBe(false); - }); + it("should compare paths case-sensitively", () => { + expect(arePathsEqual("/Users/Test", "/Users/test")).toBe(false) + }) - it('should normalize paths', () => { - expect(arePathsEqual('/Users/./Test', '/Users/Test')).toBe(true); - }); + it("should normalize paths", () => { + expect(arePathsEqual("/Users/./Test", "/Users/Test")).toBe(true) + }) - it('should handle trailing slashes', () => { - expect(arePathsEqual('/Users/Test/', '/Users/Test')).toBe(true); - }); - }); + it("should handle trailing slashes", () => { + expect(arePathsEqual("/Users/Test/", "/Users/Test")).toBe(true) + }) + }) - describe('edge cases', () => { - it('should handle undefined paths', () => { - expect(arePathsEqual(undefined, undefined)).toBe(true); - expect(arePathsEqual('/test', undefined)).toBe(false); - expect(arePathsEqual(undefined, '/test')).toBe(false); - }); + describe("edge cases", () => { + it("should handle undefined paths", () => { + expect(arePathsEqual(undefined, undefined)).toBe(true) + expect(arePathsEqual("/test", undefined)).toBe(false) + expect(arePathsEqual(undefined, "/test")).toBe(false) + }) - it('should handle root paths with trailing slashes', () => { - expect(arePathsEqual('/', '/')).toBe(true); - expect(arePathsEqual('C:\\', 'C:\\')).toBe(true); - }); - }); - }); + it("should handle root paths with trailing slashes", () => { + expect(arePathsEqual("/", "/")).toBe(true) + expect(arePathsEqual("C:\\", "C:\\")).toBe(true) + }) + }) + }) - describe('getReadablePath', () => { - const homeDir = os.homedir(); - const desktop = path.join(homeDir, 'Desktop'); + describe("getReadablePath", () => { + const homeDir = os.homedir() + const desktop = path.join(homeDir, "Desktop") - it('should return basename when path equals cwd', () => { - const cwd = '/Users/test/project'; - expect(getReadablePath(cwd, cwd)).toBe('project'); - }); + it("should return basename when path equals cwd", () => { + const cwd = "/Users/test/project" + expect(getReadablePath(cwd, cwd)).toBe("project") + }) - it('should return relative path when inside cwd', () => { - const cwd = '/Users/test/project'; - const filePath = '/Users/test/project/src/file.txt'; - expect(getReadablePath(cwd, filePath)).toBe('src/file.txt'); - }); + it("should return relative path when inside cwd", () => { + const cwd = "/Users/test/project" + const filePath = "/Users/test/project/src/file.txt" + expect(getReadablePath(cwd, filePath)).toBe("src/file.txt") + }) - it('should return absolute path when outside cwd', () => { - const cwd = '/Users/test/project'; - const filePath = '/Users/test/other/file.txt'; - expect(getReadablePath(cwd, filePath)).toBe('/Users/test/other/file.txt'); - }); + it("should return absolute path when outside cwd", () => { + const cwd = "/Users/test/project" + const filePath = "/Users/test/other/file.txt" + expect(getReadablePath(cwd, filePath)).toBe("/Users/test/other/file.txt") + }) - it('should handle Desktop as cwd', () => { - const filePath = path.join(desktop, 'file.txt'); - expect(getReadablePath(desktop, filePath)).toBe(filePath.toPosix()); - }); + it("should handle Desktop as cwd", () => { + const filePath = path.join(desktop, "file.txt") + expect(getReadablePath(desktop, filePath)).toBe(filePath.toPosix()) + }) - it('should handle undefined relative path', () => { - const cwd = '/Users/test/project'; - expect(getReadablePath(cwd)).toBe('project'); - }); + it("should handle undefined relative path", () => { + const cwd = "/Users/test/project" + expect(getReadablePath(cwd)).toBe("project") + }) - it('should handle parent directory traversal', () => { - const cwd = '/Users/test/project'; - const filePath = '../../other/file.txt'; - expect(getReadablePath(cwd, filePath)).toBe('/Users/other/file.txt'); - }); + it("should handle parent directory traversal", () => { + const cwd = "/Users/test/project" + const filePath = "../../other/file.txt" + expect(getReadablePath(cwd, filePath)).toBe("/Users/other/file.txt") + }) - it('should normalize paths with redundant segments', () => { - const cwd = '/Users/test/project'; - const filePath = '/Users/test/project/./src/../src/file.txt'; - expect(getReadablePath(cwd, filePath)).toBe('src/file.txt'); - }); - }); -}); \ No newline at end of file + it("should normalize paths with redundant segments", () => { + const cwd = "/Users/test/project" + const filePath = "/Users/test/project/./src/../src/file.txt" + expect(getReadablePath(cwd, filePath)).toBe("src/file.txt") + }) + }) +}) diff --git a/src/utils/enhance-prompt.ts b/src/utils/enhance-prompt.ts index d7c7440..3724757 100644 --- a/src/utils/enhance-prompt.ts +++ b/src/utils/enhance-prompt.ts @@ -6,22 +6,26 @@ import { defaultPrompts } from "../shared/modes" * Enhances a prompt using the configured API without creating a full Cline instance or task history. * This is a lightweight alternative that only uses the API's completion functionality. */ -export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string, enhancePrompt?: string): Promise { - if (!promptText) { - throw new Error("No prompt text provided") - } - if (!apiConfiguration || !apiConfiguration.apiProvider) { - throw new Error("No valid API configuration provided") - } - - const handler = buildApiHandler(apiConfiguration) - - // Check if handler supports single completions - if (!('completePrompt' in handler)) { - throw new Error("The selected API provider does not support prompt enhancement") - } - - const enhancePromptText = enhancePrompt ?? defaultPrompts.enhance - const prompt = `${enhancePromptText}\n\n${promptText}` - return (handler as SingleCompletionHandler).completePrompt(prompt) -} \ No newline at end of file +export async function enhancePrompt( + apiConfiguration: ApiConfiguration, + promptText: string, + enhancePrompt?: string, +): Promise { + if (!promptText) { + throw new Error("No prompt text provided") + } + if (!apiConfiguration || !apiConfiguration.apiProvider) { + throw new Error("No valid API configuration provided") + } + + const handler = buildApiHandler(apiConfiguration) + + // Check if handler supports single completions + if (!("completePrompt" in handler)) { + throw new Error("The selected API provider does not support prompt enhancement") + } + + const enhancePromptText = enhancePrompt ?? defaultPrompts.enhance + const prompt = `${enhancePromptText}\n\n${promptText}` + return (handler as SingleCompletionHandler).completePrompt(prompt) +} diff --git a/src/utils/git.ts b/src/utils/git.ts index 0cd957f..640af7f 100644 --- a/src/utils/git.ts +++ b/src/utils/git.ts @@ -15,7 +15,7 @@ export interface GitCommit { async function checkGitRepo(cwd: string): Promise { try { - await execAsync('git rev-parse --git-dir', { cwd }) + await execAsync("git rev-parse --git-dir", { cwd }) return true } catch (error) { return false @@ -24,7 +24,7 @@ async function checkGitRepo(cwd: string): Promise { async function checkGitInstalled(): Promise { try { - await execAsync('git --version') + await execAsync("git --version") return true } catch (error) { return false @@ -47,18 +47,16 @@ export async function searchCommits(query: string, cwd: string): Promise ({ stdout: "" })) if (!hashStdout.trim()) { @@ -69,7 +67,10 @@ export async function searchCommits(query: string, cwd: string): Promise line !== "--") + const lines = output + .trim() + .split("\n") + .filter((line) => line !== "--") for (let i = 0; i < lines.length; i += 5) { commits.push({ @@ -77,7 +78,7 @@ export async function searchCommits(query: string, cwd: string): Promise } // Get commit info, stats, and diff separately - const { stdout: info } = await execAsync( - `git show --format="%H%n%h%n%s%n%an%n%ad%n%b" --no-patch ${hash}`, - { cwd } - ) - const [fullHash, shortHash, subject, author, date, body] = info.trim().split('\n') - - const { stdout: stats } = await execAsync( - `git show --stat --format="" ${hash}`, - { cwd } - ) + const { stdout: info } = await execAsync(`git show --format="%H%n%h%n%s%n%an%n%ad%n%b" --no-patch ${hash}`, { + cwd, + }) + const [fullHash, shortHash, subject, author, date, body] = info.trim().split("\n") - const { stdout: diff } = await execAsync( - `git show --format="" ${hash}`, - { cwd } - ) + const { stdout: stats } = await execAsync(`git show --stat --format="" ${hash}`, { cwd }) + + const { stdout: diff } = await execAsync(`git show --format="" ${hash}`, { cwd }) const summary = [ `Commit: ${shortHash} (${fullHash})`, `Author: ${author}`, `Date: ${date}`, `\nMessage: ${subject}`, - body ? `\nDescription:\n${body}` : '', - '\nFiles Changed:', + body ? `\nDescription:\n${body}` : "", + "\nFiles Changed:", stats.trim(), - '\nFull Changes:' - ].join('\n') + "\nFull Changes:", + ].join("\n") - const output = summary + '\n\n' + diff.trim() + const output = summary + "\n\n" + diff.trim() return truncateOutput(output, GIT_OUTPUT_LINE_LIMIT) } catch (error) { console.error("Error getting commit info:", error) @@ -149,13 +143,13 @@ export async function getWorkingState(cwd: string): Promise { } // Get status of working directory - const { stdout: status } = await execAsync('git status --short', { cwd }) + const { stdout: status } = await execAsync("git status --short", { cwd }) if (!status.trim()) { return "No changes in working directory" } // Get all changes (both staged and unstaged) compared to HEAD - const { stdout: diff } = await execAsync('git diff HEAD', { cwd }) + const { stdout: diff } = await execAsync("git diff HEAD", { cwd }) const lineLimit = GIT_OUTPUT_LINE_LIMIT const output = `Working directory changes:\n\n${status}\n\n${diff}`.trim() return truncateOutput(output, lineLimit) @@ -163,4 +157,4 @@ export async function getWorkingState(cwd: string): Promise { console.error("Error getting working state:", error) return `Failed to get working state: ${error instanceof Error ? error.message : String(error)}` } -} \ No newline at end of file +} diff --git a/src/utils/sound.ts b/src/utils/sound.ts index a7f0d73..877a041 100644 --- a/src/utils/sound.ts +++ b/src/utils/sound.ts @@ -21,7 +21,7 @@ export const isWAV = (filepath: string): boolean => { } let isSoundEnabled = false -let volume = .5 +let volume = 0.5 /** * Set sound configuration diff --git a/webview-ui/config-overrides.js b/webview-ui/config-overrides.js index 857b363..65bd531 100644 --- a/webview-ui/config-overrides.js +++ b/webview-ui/config-overrides.js @@ -1,22 +1,22 @@ -const { override } = require('customize-cra'); +const { override } = require("customize-cra") -module.exports = override(); +module.exports = override() // Jest configuration override -module.exports.jest = function(config) { - // Configure reporters - config.reporters = [["jest-simple-dot-reporter", {}]]; - - // Configure module name mapper for CSS modules - config.moduleNameMapper = { - ...config.moduleNameMapper, - "\\.(css|less|scss|sass)$": "identity-obj-proxy" - }; - - // Configure transform ignore patterns for ES modules - config.transformIgnorePatterns = [ - '/node_modules/(?!(rehype-highlight|react-remark|unist-util-visit|unist-util-find-after|vfile|unified|bail|is-plain-obj|trough|vfile-message|unist-util-stringify-position|mdast-util-from-markdown|mdast-util-to-string|micromark|decode-named-character-reference|character-entities|markdown-table|zwitch|longest-streak|escape-string-regexp|unist-util-is|hast-util-to-text|@vscode/webview-ui-toolkit|@microsoft/fast-react-wrapper|@microsoft/fast-element|@microsoft/fast-foundation|@microsoft/fast-web-utilities|exenv-es6|vscrui)/)' - ]; - - return config; -} \ No newline at end of file +module.exports.jest = function (config) { + // Configure reporters + config.reporters = [["jest-simple-dot-reporter", {}]] + + // Configure module name mapper for CSS modules + config.moduleNameMapper = { + ...config.moduleNameMapper, + "\\.(css|less|scss|sass)$": "identity-obj-proxy", + } + + // Configure transform ignore patterns for ES modules + config.transformIgnorePatterns = [ + "/node_modules/(?!(rehype-highlight|react-remark|unist-util-visit|unist-util-find-after|vfile|unified|bail|is-plain-obj|trough|vfile-message|unist-util-stringify-position|mdast-util-from-markdown|mdast-util-to-string|micromark|decode-named-character-reference|character-entities|markdown-table|zwitch|longest-streak|escape-string-regexp|unist-util-is|hast-util-to-text|@vscode/webview-ui-toolkit|@microsoft/fast-react-wrapper|@microsoft/fast-element|@microsoft/fast-foundation|@microsoft/fast-web-utilities|exenv-es6|vscrui)/)", + ] + + return config +} diff --git a/webview-ui/src/components/chat/Announcement.tsx b/webview-ui/src/components/chat/Announcement.tsx index 04b4974..f86f871 100644 --- a/webview-ui/src/components/chat/Announcement.tsx +++ b/webview-ui/src/components/chat/Announcement.tsx @@ -33,25 +33,28 @@ const Announcement = ({ version, hideAnnouncement }: AnnouncementProps) => { 🎉{" "}Introducing Roo Cline v{minorVersion} -

- Agent Modes Customization -

+

Agent Modes Customization

- Click the new icon in the menu bar to open the Prompts Settings and customize Agent Modes for new levels of productivity. + Click the new icon in + the menu bar to open the Prompts Settings and customize Agent Modes for new levels of productivity.

  • Tailor how Roo Cline behaves in different modes: Code, Architect, and Ask.
  • Preview and verify your changes using the Preview System Prompt button.

-

- Prompt Enhancement Configuration -

+

Prompt Enhancement Configuration

- Now available for all providers! Access it directly in the chat box by clicking the sparkle icon next to the input field. From there, you can customize the enhancement logic and provider to best suit your workflow. + Now available for all providers! Access it directly in the chat box by clicking the{" "} + sparkle icon next to the + input field. From there, you can customize the enhancement logic and provider to best suit your + workflow.

  • Customize how prompts are enhanced for better results in your workflow.
  • -
  • Use the sparkle icon in the chat box to select a API configuration and provider (e.g., GPT-4) and configure your own enhancement logic.
  • +
  • + Use the sparkle icon in the chat box to select a API configuration and provider (e.g., GPT-4) + and configure your own enhancement logic. +
  • Test your changes instantly with the Preview Prompt Enhancement tool.

diff --git a/webview-ui/src/components/chat/AutoApproveMenu.tsx b/webview-ui/src/components/chat/AutoApproveMenu.tsx index f451638..1eb2fc4 100644 --- a/webview-ui/src/components/chat/AutoApproveMenu.tsx +++ b/webview-ui/src/components/chat/AutoApproveMenu.tsx @@ -127,7 +127,7 @@ const AutoApproveMenu = ({ style }: AutoApproveMenuProps) => { }, [alwaysApproveResubmit, setAlwaysApproveResubmit]) // Map action IDs to their specific handlers - const actionHandlers: Record void> = { + const actionHandlers: Record void> = { readFiles: handleReadOnlyChange, editFiles: handleWriteChange, executeCommands: handleExecuteChange, @@ -166,25 +166,30 @@ const AutoApproveMenu = ({ style }: AutoApproveMenuProps) => { }} /> -
- Auto-approve: - + + Auto-approve: + + {enabledActionsList || "None"} { {actions.map((action) => (
e.stopPropagation()}> - + {action.label}
diff --git a/webview-ui/src/components/chat/BrowserSessionRow.tsx b/webview-ui/src/components/chat/BrowserSessionRow.tsx index 682cfd2..0b06601 100644 --- a/webview-ui/src/components/chat/BrowserSessionRow.tsx +++ b/webview-ui/src/components/chat/BrowserSessionRow.tsx @@ -31,8 +31,8 @@ const BrowserSessionRow = memo((props: BrowserSessionRowProps) => { const { browserViewportSize = "900x600" } = useExtensionState() const [viewportWidth, viewportHeight] = browserViewportSize.split("x").map(Number) - const aspectRatio = (viewportHeight / viewportWidth * 100).toFixed(2) - const defaultMousePosition = `${Math.round(viewportWidth/2)},${Math.round(viewportHeight/2)}` + const aspectRatio = ((viewportHeight / viewportWidth) * 100).toFixed(2) + const defaultMousePosition = `${Math.round(viewportWidth / 2)},${Math.round(viewportHeight / 2)}` const isLastApiReqInterrupted = useMemo(() => { // Check if last api_req_started is cancelled @@ -171,7 +171,8 @@ const BrowserSessionRow = memo((props: BrowserSessionRowProps) => { const displayState = isLastPage ? { url: currentPage?.currentState.url || latestState.url || initialUrl, - mousePosition: currentPage?.currentState.mousePosition || latestState.mousePosition || defaultMousePosition, + mousePosition: + currentPage?.currentState.mousePosition || latestState.mousePosition || defaultMousePosition, consoleLogs: currentPage?.currentState.consoleLogs, screenshot: currentPage?.currentState.screenshot || latestState.screenshot, } @@ -226,7 +227,9 @@ const BrowserSessionRow = memo((props: BrowserSessionRowProps) => { }, [isBrowsing, currentPage?.nextAction?.messages]) // Use latest click position while browsing, otherwise use display state - const mousePosition = isBrowsing ? latestClickPosition || displayState.mousePosition : displayState.mousePosition || defaultMousePosition + const mousePosition = isBrowsing + ? latestClickPosition || displayState.mousePosition + : displayState.mousePosition || defaultMousePosition const [browserSessionRow, { height: rowHeight }] = useSize(
diff --git a/webview-ui/src/components/chat/ChatRow.tsx b/webview-ui/src/components/chat/ChatRow.tsx index 3089a13..8cda508 100644 --- a/webview-ui/src/components/chat/ChatRow.tsx +++ b/webview-ui/src/components/chat/ChatRow.tsx @@ -565,7 +565,13 @@ export const ChatRowContent = ({ whiteSpace: "pre-line", wordWrap: "break-word", }}> -
+
{highlightMentions(message.text)} { - e.stopPropagation(); + e.stopPropagation() vscode.postMessage({ type: "deleteMessage", - value: message.ts - }); - }} - > + value: message.ts, + }) + }}>
@@ -835,10 +840,13 @@ export const ChatRowContent = ({ tool={{ name: useMcpServer.toolName || "", description: - server?.tools?.find((tool) => tool.name === useMcpServer.toolName) - ?.description || "", - alwaysAllow: server?.tools?.find((tool) => tool.name === useMcpServer.toolName) - ?.alwaysAllow || false, + server?.tools?.find( + (tool) => tool.name === useMcpServer.toolName, + )?.description || "", + alwaysAllow: + server?.tools?.find( + (tool) => tool.name === useMcpServer.toolName, + )?.alwaysAllow || false, }} serverName={useMcpServer.serverName} /> @@ -919,14 +927,13 @@ export const ProgressIndicator = () => ( ) const Markdown = memo(({ markdown, partial }: { markdown?: string; partial?: boolean }) => { - const [isHovering, setIsHovering] = useState(false); + const [isHovering, setIsHovering] = useState(false) return (
setIsHovering(true)} onMouseLeave={() => setIsHovering(false)} - style={{ position: "relative" }} - > + style={{ position: "relative" }}>
@@ -938,9 +945,8 @@ const Markdown = memo(({ markdown, partial }: { markdown?: string; partial?: boo right: "8px", opacity: 0, animation: "fadeIn 0.2s ease-in-out forwards", - borderRadius: "4px" - }} - > + borderRadius: "4px", + }}> - {showCopyModal && ( -
- Prompt Copied to Clipboard -
- )} + {showCopyModal &&
Prompt Copied to Clipboard
}
{ components={{ List: React.forwardRef((props, ref) => (
- )) + )), }} itemContent={(index, item) => (
{
diff --git a/webview-ui/src/components/history/__tests__/HistoryView.test.tsx b/webview-ui/src/components/history/__tests__/HistoryView.test.tsx index 3b7623e..1f1e4c7 100644 --- a/webview-ui/src/components/history/__tests__/HistoryView.test.tsx +++ b/webview-ui/src/components/history/__tests__/HistoryView.test.tsx @@ -1,232 +1,235 @@ -import React from 'react' -import { render, screen, fireEvent, within, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import HistoryView from '../HistoryView' -import { useExtensionState } from '../../../context/ExtensionStateContext' -import { vscode } from '../../../utils/vscode' +import React from "react" +import { render, screen, fireEvent, within, waitFor } from "@testing-library/react" +import userEvent from "@testing-library/user-event" +import HistoryView from "../HistoryView" +import { useExtensionState } from "../../../context/ExtensionStateContext" +import { vscode } from "../../../utils/vscode" // Mock dependencies -jest.mock('../../../context/ExtensionStateContext') -jest.mock('../../../utils/vscode') -jest.mock('react-virtuoso', () => ({ - Virtuoso: ({ data, itemContent }: any) => ( -
- {data.map((item: any, index: number) => ( -
- {itemContent(index, item)} -
- ))} -
- ), +jest.mock("../../../context/ExtensionStateContext") +jest.mock("../../../utils/vscode") +jest.mock("react-virtuoso", () => ({ + Virtuoso: ({ data, itemContent }: any) => ( +
+ {data.map((item: any, index: number) => ( +
+ {itemContent(index, item)} +
+ ))} +
+ ), })) const mockTaskHistory = [ - { - id: '1', - task: 'Test task 1', - ts: new Date('2022-02-16T00:00:00').getTime(), - tokensIn: 100, - tokensOut: 50, - totalCost: 0.002, - }, - { - id: '2', - task: 'Test task 2', - ts: new Date('2022-02-17T00:00:00').getTime(), - tokensIn: 200, - tokensOut: 100, - cacheWrites: 50, - cacheReads: 25, - }, + { + id: "1", + task: "Test task 1", + ts: new Date("2022-02-16T00:00:00").getTime(), + tokensIn: 100, + tokensOut: 50, + totalCost: 0.002, + }, + { + id: "2", + task: "Test task 2", + ts: new Date("2022-02-17T00:00:00").getTime(), + tokensIn: 200, + tokensOut: 100, + cacheWrites: 50, + cacheReads: 25, + }, ] -describe('HistoryView', () => { - beforeEach(() => { - // Reset all mocks before each test - jest.clearAllMocks() - jest.useFakeTimers() - - // Mock useExtensionState implementation - ;(useExtensionState as jest.Mock).mockReturnValue({ - taskHistory: mockTaskHistory, - }) - }) +describe("HistoryView", () => { + beforeEach(() => { + // Reset all mocks before each test + jest.clearAllMocks() + jest.useFakeTimers() - afterEach(() => { - jest.useRealTimers() - }) + // Mock useExtensionState implementation + ;(useExtensionState as jest.Mock).mockReturnValue({ + taskHistory: mockTaskHistory, + }) + }) - it('renders history items correctly', () => { - const onDone = jest.fn() - render() + afterEach(() => { + jest.useRealTimers() + }) - // Check if both tasks are rendered - expect(screen.getByTestId('virtuoso-item-1')).toBeInTheDocument() - expect(screen.getByTestId('virtuoso-item-2')).toBeInTheDocument() - expect(screen.getByText('Test task 1')).toBeInTheDocument() - expect(screen.getByText('Test task 2')).toBeInTheDocument() - }) + it("renders history items correctly", () => { + const onDone = jest.fn() + render() - it('handles search functionality', async () => { - const onDone = jest.fn() - render() + // Check if both tasks are rendered + expect(screen.getByTestId("virtuoso-item-1")).toBeInTheDocument() + expect(screen.getByTestId("virtuoso-item-2")).toBeInTheDocument() + expect(screen.getByText("Test task 1")).toBeInTheDocument() + expect(screen.getByText("Test task 2")).toBeInTheDocument() + }) - // Get search input and radio group - const searchInput = screen.getByPlaceholderText('Fuzzy search history...') - const radioGroup = screen.getByRole('radiogroup') - - // Type in search - await userEvent.type(searchInput, 'task 1') + it("handles search functionality", async () => { + const onDone = jest.fn() + render() - // Check if sort option automatically changes to "Most Relevant" - const mostRelevantRadio = within(radioGroup).getByLabelText('Most Relevant') - expect(mostRelevantRadio).not.toBeDisabled() - - // Click and wait for radio update - fireEvent.click(mostRelevantRadio) + // Get search input and radio group + const searchInput = screen.getByPlaceholderText("Fuzzy search history...") + const radioGroup = screen.getByRole("radiogroup") - // Wait for radio button to be checked - const updatedRadio = await within(radioGroup).findByRole('radio', { name: 'Most Relevant', checked: true }) - expect(updatedRadio).toBeInTheDocument() - }) + // Type in search + await userEvent.type(searchInput, "task 1") - it('handles sort options correctly', async () => { - const onDone = jest.fn() - render() + // Check if sort option automatically changes to "Most Relevant" + const mostRelevantRadio = within(radioGroup).getByLabelText("Most Relevant") + expect(mostRelevantRadio).not.toBeDisabled() - const radioGroup = screen.getByRole('radiogroup') + // Click and wait for radio update + fireEvent.click(mostRelevantRadio) - // Test changing sort options - const oldestRadio = within(radioGroup).getByLabelText('Oldest') - fireEvent.click(oldestRadio) - - // Wait for oldest radio to be checked - const checkedOldestRadio = await within(radioGroup).findByRole('radio', { name: 'Oldest', checked: true }) - expect(checkedOldestRadio).toBeInTheDocument() + // Wait for radio button to be checked + const updatedRadio = await within(radioGroup).findByRole("radio", { name: "Most Relevant", checked: true }) + expect(updatedRadio).toBeInTheDocument() + }) - const mostExpensiveRadio = within(radioGroup).getByLabelText('Most Expensive') - fireEvent.click(mostExpensiveRadio) - - // Wait for most expensive radio to be checked - const checkedExpensiveRadio = await within(radioGroup).findByRole('radio', { name: 'Most Expensive', checked: true }) - expect(checkedExpensiveRadio).toBeInTheDocument() - }) + it("handles sort options correctly", async () => { + const onDone = jest.fn() + render() - it('handles task selection', () => { - const onDone = jest.fn() - render() + const radioGroup = screen.getByRole("radiogroup") - // Click on first task - fireEvent.click(screen.getByText('Test task 1')) + // Test changing sort options + const oldestRadio = within(radioGroup).getByLabelText("Oldest") + fireEvent.click(oldestRadio) - // Verify vscode message was sent - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'showTaskWithId', - text: '1', - }) - }) + // Wait for oldest radio to be checked + const checkedOldestRadio = await within(radioGroup).findByRole("radio", { name: "Oldest", checked: true }) + expect(checkedOldestRadio).toBeInTheDocument() - it('handles task deletion', () => { - const onDone = jest.fn() - render() + const mostExpensiveRadio = within(radioGroup).getByLabelText("Most Expensive") + fireEvent.click(mostExpensiveRadio) - // Find and hover over first task - const taskContainer = screen.getByTestId('virtuoso-item-1') - fireEvent.mouseEnter(taskContainer) - - const deleteButton = within(taskContainer).getByTitle('Delete Task') - fireEvent.click(deleteButton) + // Wait for most expensive radio to be checked + const checkedExpensiveRadio = await within(radioGroup).findByRole("radio", { + name: "Most Expensive", + checked: true, + }) + expect(checkedExpensiveRadio).toBeInTheDocument() + }) - // Verify vscode message was sent - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'deleteTaskWithId', - text: '1', - }) - }) + it("handles task selection", () => { + const onDone = jest.fn() + render() - it('handles task copying', async () => { - const mockClipboard = { - writeText: jest.fn().mockResolvedValue(undefined), - } - Object.assign(navigator, { clipboard: mockClipboard }) + // Click on first task + fireEvent.click(screen.getByText("Test task 1")) - const onDone = jest.fn() - render() + // Verify vscode message was sent + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "showTaskWithId", + text: "1", + }) + }) - // Find and hover over first task - const taskContainer = screen.getByTestId('virtuoso-item-1') - fireEvent.mouseEnter(taskContainer) - - const copyButton = within(taskContainer).getByTitle('Copy Prompt') - await userEvent.click(copyButton) + it("handles task deletion", () => { + const onDone = jest.fn() + render() - // Verify clipboard API was called - expect(navigator.clipboard.writeText).toHaveBeenCalledWith('Test task 1') - - // Wait for copy modal to appear - const copyModal = await screen.findByText('Prompt Copied to Clipboard') - expect(copyModal).toBeInTheDocument() + // Find and hover over first task + const taskContainer = screen.getByTestId("virtuoso-item-1") + fireEvent.mouseEnter(taskContainer) - // Fast-forward timers and wait for modal to disappear - jest.advanceTimersByTime(2000) - await waitFor(() => { - expect(screen.queryByText('Prompt Copied to Clipboard')).not.toBeInTheDocument() - }) - }) + const deleteButton = within(taskContainer).getByTitle("Delete Task") + fireEvent.click(deleteButton) - it('formats dates correctly', () => { - const onDone = jest.fn() - render() + // Verify vscode message was sent + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "deleteTaskWithId", + text: "1", + }) + }) - // Find first task container and check date format - const taskContainer = screen.getByTestId('virtuoso-item-1') - const dateElement = within(taskContainer).getByText((content) => { - return content.includes('FEBRUARY 16') && content.includes('12:00 AM') - }) - expect(dateElement).toBeInTheDocument() - }) + it("handles task copying", async () => { + const mockClipboard = { + writeText: jest.fn().mockResolvedValue(undefined), + } + Object.assign(navigator, { clipboard: mockClipboard }) - it('displays token counts correctly', () => { - const onDone = jest.fn() - render() + const onDone = jest.fn() + render() - // Find first task container - const taskContainer = screen.getByTestId('virtuoso-item-1') + // Find and hover over first task + const taskContainer = screen.getByTestId("virtuoso-item-1") + fireEvent.mouseEnter(taskContainer) - // Find token counts within the task container - const tokensContainer = within(taskContainer).getByTestId('tokens-container') - expect(within(tokensContainer).getByTestId('tokens-in')).toHaveTextContent('100') - expect(within(tokensContainer).getByTestId('tokens-out')).toHaveTextContent('50') - }) + const copyButton = within(taskContainer).getByTitle("Copy Prompt") + await userEvent.click(copyButton) - it('displays cache information when available', () => { - const onDone = jest.fn() - render() + // Verify clipboard API was called + expect(navigator.clipboard.writeText).toHaveBeenCalledWith("Test task 1") - // Find second task container - const taskContainer = screen.getByTestId('virtuoso-item-2') + // Wait for copy modal to appear + const copyModal = await screen.findByText("Prompt Copied to Clipboard") + expect(copyModal).toBeInTheDocument() - // Find cache info within the task container - const cacheContainer = within(taskContainer).getByTestId('cache-container') - expect(within(cacheContainer).getByTestId('cache-writes')).toHaveTextContent('+50') - expect(within(cacheContainer).getByTestId('cache-reads')).toHaveTextContent('25') - }) + // Fast-forward timers and wait for modal to disappear + jest.advanceTimersByTime(2000) + await waitFor(() => { + expect(screen.queryByText("Prompt Copied to Clipboard")).not.toBeInTheDocument() + }) + }) - it('handles export functionality', () => { - const onDone = jest.fn() - render() + it("formats dates correctly", () => { + const onDone = jest.fn() + render() - // Find and hover over second task - const taskContainer = screen.getByTestId('virtuoso-item-2') - fireEvent.mouseEnter(taskContainer) - - const exportButton = within(taskContainer).getByText('EXPORT') - fireEvent.click(exportButton) + // Find first task container and check date format + const taskContainer = screen.getByTestId("virtuoso-item-1") + const dateElement = within(taskContainer).getByText((content) => { + return content.includes("FEBRUARY 16") && content.includes("12:00 AM") + }) + expect(dateElement).toBeInTheDocument() + }) - // Verify vscode message was sent - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'exportTaskWithId', - text: '2', - }) - }) -}) \ No newline at end of file + it("displays token counts correctly", () => { + const onDone = jest.fn() + render() + + // Find first task container + const taskContainer = screen.getByTestId("virtuoso-item-1") + + // Find token counts within the task container + const tokensContainer = within(taskContainer).getByTestId("tokens-container") + expect(within(tokensContainer).getByTestId("tokens-in")).toHaveTextContent("100") + expect(within(tokensContainer).getByTestId("tokens-out")).toHaveTextContent("50") + }) + + it("displays cache information when available", () => { + const onDone = jest.fn() + render() + + // Find second task container + const taskContainer = screen.getByTestId("virtuoso-item-2") + + // Find cache info within the task container + const cacheContainer = within(taskContainer).getByTestId("cache-container") + expect(within(cacheContainer).getByTestId("cache-writes")).toHaveTextContent("+50") + expect(within(cacheContainer).getByTestId("cache-reads")).toHaveTextContent("25") + }) + + it("handles export functionality", () => { + const onDone = jest.fn() + render() + + // Find and hover over second task + const taskContainer = screen.getByTestId("virtuoso-item-2") + fireEvent.mouseEnter(taskContainer) + + const exportButton = within(taskContainer).getByText("EXPORT") + fireEvent.click(exportButton) + + // Verify vscode message was sent + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "exportTaskWithId", + text: "2", + }) + }) +}) diff --git a/webview-ui/src/components/mcp/McpEnabledToggle.tsx b/webview-ui/src/components/mcp/McpEnabledToggle.tsx index 41c94c7..41bbee9 100644 --- a/webview-ui/src/components/mcp/McpEnabledToggle.tsx +++ b/webview-ui/src/components/mcp/McpEnabledToggle.tsx @@ -7,7 +7,7 @@ const McpEnabledToggle = () => { const { mcpEnabled, setMcpEnabled } = useExtensionState() const handleChange = (e: Event | FormEvent) => { - const target = ('target' in e ? e.target : null) as HTMLInputElement | null + const target = ("target" in e ? e.target : null) as HTMLInputElement | null if (!target) return setMcpEnabled(target.checked) vscode.postMessage({ type: "mcpEnabled", bool: target.checked }) @@ -15,20 +15,20 @@ const McpEnabledToggle = () => { return (
- + Enable MCP Servers -

- When enabled, Cline will be able to interact with MCP servers for advanced functionality. If you're not using MCP, you can disable this to reduce Cline's token usage. +

+ When enabled, Cline will be able to interact with MCP servers for advanced functionality. If you're not + using MCP, you can disable this to reduce Cline's token usage.

) } -export default McpEnabledToggle \ No newline at end of file +export default McpEnabledToggle diff --git a/webview-ui/src/components/mcp/McpToolRow.tsx b/webview-ui/src/components/mcp/McpToolRow.tsx index f382cc6..20fc1ac 100644 --- a/webview-ui/src/components/mcp/McpToolRow.tsx +++ b/webview-ui/src/components/mcp/McpToolRow.tsx @@ -10,14 +10,14 @@ type McpToolRowProps = { const McpToolRow = ({ tool, serverName, alwaysAllowMcp }: McpToolRowProps) => { const handleAlwaysAllowChange = () => { - if (!serverName) return; - + if (!serverName) return + vscode.postMessage({ type: "toggleToolAlwaysAllow", serverName, toolName: tool.name, - alwaysAllow: !tool.alwaysAllow - }); + alwaysAllow: !tool.alwaysAllow, + }) } return ( @@ -35,10 +35,7 @@ const McpToolRow = ({ tool, serverName, alwaysAllowMcp }: McpToolRowProps) => { {tool.name}
{serverName && alwaysAllowMcp && ( - + Always allow )} diff --git a/webview-ui/src/components/mcp/McpView.tsx b/webview-ui/src/components/mcp/McpView.tsx index ea2ea2b..607a736 100644 --- a/webview-ui/src/components/mcp/McpView.tsx +++ b/webview-ui/src/components/mcp/McpView.tsx @@ -159,7 +159,7 @@ const McpView = ({ onDone }: McpViewProps) => { } // Server Row Component -const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer, alwaysAllowMcp?: boolean }) => { +const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowMcp?: boolean }) => { const [isExpanded, setIsExpanded] = useState(false) const getStatusColor = () => { @@ -216,9 +216,9 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer, alwaysAllowM style={{ width: "20px", height: "10px", - backgroundColor: server.disabled ? - "var(--vscode-titleBar-inactiveForeground)" : - "var(--vscode-button-background)", + backgroundColor: server.disabled + ? "var(--vscode-titleBar-inactiveForeground)" + : "var(--vscode-button-background)", borderRadius: "5px", position: "relative", cursor: "pointer", @@ -229,30 +229,31 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer, alwaysAllowM vscode.postMessage({ type: "toggleMcpServer", serverName: server.name, - disabled: !server.disabled - }); + disabled: !server.disabled, + }) }} onKeyDown={(e) => { if (e.key === "Enter" || e.key === " ") { - e.preventDefault(); + e.preventDefault() vscode.postMessage({ type: "toggleMcpServer", serverName: server.name, - disabled: !server.disabled - }); + disabled: !server.disabled, + }) } - }} - > -
+ }}> +
({ - vscode: { - postMessage: jest.fn() - } +jest.mock("../../../utils/vscode", () => ({ + vscode: { + postMessage: jest.fn(), + }, })) -jest.mock('@vscode/webview-ui-toolkit/react', () => ({ - VSCodeCheckbox: function MockVSCodeCheckbox({ - children, - checked, - onChange - }: { - children?: React.ReactNode; - checked?: boolean; - onChange?: (e: React.ChangeEvent) => void; - }) { - return ( - - ) - } +jest.mock("@vscode/webview-ui-toolkit/react", () => ({ + VSCodeCheckbox: function MockVSCodeCheckbox({ + children, + checked, + onChange, + }: { + children?: React.ReactNode + checked?: boolean + onChange?: (e: React.ChangeEvent) => void + }) { + return ( + + ) + }, })) -describe('McpToolRow', () => { - const mockTool = { - name: 'test-tool', - description: 'A test tool', - alwaysAllow: false - } +describe("McpToolRow", () => { + const mockTool = { + name: "test-tool", + description: "A test tool", + alwaysAllow: false, + } - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => { + jest.clearAllMocks() + }) - it('renders tool name and description', () => { - render() - - expect(screen.getByText('test-tool')).toBeInTheDocument() - expect(screen.getByText('A test tool')).toBeInTheDocument() - }) + it("renders tool name and description", () => { + render() - it('does not show always allow checkbox when serverName is not provided', () => { - render() - - expect(screen.queryByText('Always allow')).not.toBeInTheDocument() - }) + expect(screen.getByText("test-tool")).toBeInTheDocument() + expect(screen.getByText("A test tool")).toBeInTheDocument() + }) - it('shows always allow checkbox when serverName and alwaysAllowMcp are provided', () => { - render() - - expect(screen.getByText('Always allow')).toBeInTheDocument() - }) - - it('sends message to toggle always allow when checkbox is clicked', () => { - render() - - const checkbox = screen.getByRole('checkbox') - fireEvent.click(checkbox) - - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'toggleToolAlwaysAllow', - serverName: 'test-server', - toolName: 'test-tool', - alwaysAllow: true - }) - }) - - it('reflects always allow state in checkbox', () => { - const alwaysAllowedTool = { - ...mockTool, - alwaysAllow: true - } - - render() - - const checkbox = screen.getByRole('checkbox') as HTMLInputElement - expect(checkbox.checked).toBe(true) - }) - - it('prevents event propagation when clicking the checkbox', () => { - const mockOnClick = jest.fn() - render( -
- -
- ) - - const container = screen.getByTestId('tool-row-container') - fireEvent.click(container) - - expect(mockOnClick).not.toHaveBeenCalled() - }) + it("does not show always allow checkbox when serverName is not provided", () => { + render() - it('displays input schema parameters when provided', () => { - const toolWithSchema = { - ...mockTool, - inputSchema: { - type: 'object', - properties: { - param1: { - type: 'string', - description: 'First parameter' - }, - param2: { - type: 'number', - description: 'Second parameter' - } - }, - required: ['param1'] - } - } + expect(screen.queryByText("Always allow")).not.toBeInTheDocument() + }) - render() - - expect(screen.getByText('Parameters')).toBeInTheDocument() - expect(screen.getByText('param1')).toBeInTheDocument() - expect(screen.getByText('param2')).toBeInTheDocument() - expect(screen.getByText('First parameter')).toBeInTheDocument() - expect(screen.getByText('Second parameter')).toBeInTheDocument() - }) -}) \ No newline at end of file + it("shows always allow checkbox when serverName and alwaysAllowMcp are provided", () => { + render() + + expect(screen.getByText("Always allow")).toBeInTheDocument() + }) + + it("sends message to toggle always allow when checkbox is clicked", () => { + render() + + const checkbox = screen.getByRole("checkbox") + fireEvent.click(checkbox) + + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "toggleToolAlwaysAllow", + serverName: "test-server", + toolName: "test-tool", + alwaysAllow: true, + }) + }) + + it("reflects always allow state in checkbox", () => { + const alwaysAllowedTool = { + ...mockTool, + alwaysAllow: true, + } + + render() + + const checkbox = screen.getByRole("checkbox") as HTMLInputElement + expect(checkbox.checked).toBe(true) + }) + + it("prevents event propagation when clicking the checkbox", () => { + const mockOnClick = jest.fn() + render( +
+ +
, + ) + + const container = screen.getByTestId("tool-row-container") + fireEvent.click(container) + + expect(mockOnClick).not.toHaveBeenCalled() + }) + + it("displays input schema parameters when provided", () => { + const toolWithSchema = { + ...mockTool, + inputSchema: { + type: "object", + properties: { + param1: { + type: "string", + description: "First parameter", + }, + param2: { + type: "number", + description: "Second parameter", + }, + }, + required: ["param1"], + }, + } + + render() + + expect(screen.getByText("Parameters")).toBeInTheDocument() + expect(screen.getByText("param1")).toBeInTheDocument() + expect(screen.getByText("param2")).toBeInTheDocument() + expect(screen.getByText("First parameter")).toBeInTheDocument() + expect(screen.getByText("Second parameter")).toBeInTheDocument() + }) +}) diff --git a/webview-ui/src/components/prompts/PromptsView.tsx b/webview-ui/src/components/prompts/PromptsView.tsx index 497bac7..cebcf1a 100644 --- a/webview-ui/src/components/prompts/PromptsView.tsx +++ b/webview-ui/src/components/prompts/PromptsView.tsx @@ -8,9 +8,9 @@ type PromptsViewProps = { onDone: () => void } -const AGENT_MODES = modes.map(mode => ({ +const AGENT_MODES = modes.map((mode) => ({ id: mode.slug, - label: mode.name + label: mode.name, })) const PromptsView = ({ onDone }: PromptsViewProps) => { @@ -21,24 +21,24 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { setEnhancementApiConfigId, mode, customInstructions, - setCustomInstructions + setCustomInstructions, } = useExtensionState() - const [testPrompt, setTestPrompt] = useState('') + const [testPrompt, setTestPrompt] = useState("") const [isEnhancing, setIsEnhancing] = useState(false) const [activeTab, setActiveTab] = useState(mode) const [isDialogOpen, setIsDialogOpen] = useState(false) - const [selectedPromptContent, setSelectedPromptContent] = useState('') - const [selectedPromptTitle, setSelectedPromptTitle] = useState('') + const [selectedPromptContent, setSelectedPromptContent] = useState("") + const [selectedPromptTitle, setSelectedPromptTitle] = useState("") useEffect(() => { const handler = (event: MessageEvent) => { const message = event.data - if (message.type === 'enhancedPrompt') { + if (message.type === "enhancedPrompt") { if (message.text) { setTestPrompt(message.text) } setIsEnhancing(false) - } else if (message.type === 'systemPrompt') { + } else if (message.type === "systemPrompt") { if (message.text) { setSelectedPromptContent(message.text) setSelectedPromptTitle(`System Prompt (${message.mode} mode)`) @@ -47,17 +47,15 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { } } - window.addEventListener('message', handler) - return () => window.removeEventListener('message', handler) + window.addEventListener("message", handler) + return () => window.removeEventListener("message", handler) }, []) - type AgentMode = string; + type AgentMode = string const updateAgentPrompt = (mode: Mode, promptData: PromptComponent) => { const existingPrompt = customPrompts?.[mode] - const updatedPrompt = typeof existingPrompt === 'object' - ? { ...existingPrompt, ...promptData } - : promptData + const updatedPrompt = typeof existingPrompt === "object" ? { ...existingPrompt, ...promptData } : promptData // Only include properties that differ from defaults if (updatedPrompt.roleDefinition === getRoleDefinition(mode)) { @@ -67,14 +65,14 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { vscode.postMessage({ type: "updatePrompt", promptMode: mode, - customPrompt: updatedPrompt + customPrompt: updatedPrompt, }) } const updateEnhancePrompt = (value: string | undefined) => { vscode.postMessage({ type: "updateEnhancedPrompt", - text: value + text: value, }) } @@ -94,8 +92,8 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { const handleAgentReset = (mode: AgentMode) => { const existingPrompt = customPrompts?.[mode] updateAgentPrompt(mode, { - ...(typeof existingPrompt === 'object' ? existingPrompt : {}), - roleDefinition: undefined + ...(typeof existingPrompt === "object" ? existingPrompt : {}), + roleDefinition: undefined, }) } @@ -105,22 +103,22 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { const getAgentPromptValue = (mode: Mode): string => { const prompt = customPrompts?.[mode] - return typeof prompt === 'object' ? prompt.roleDefinition ?? getRoleDefinition(mode) : getRoleDefinition(mode); + return typeof prompt === "object" ? (prompt.roleDefinition ?? getRoleDefinition(mode)) : getRoleDefinition(mode) } const getEnhancePromptValue = (): string => { const enhance = customPrompts?.enhance - const defaultEnhance = typeof defaultPrompts.enhance === 'string' ? defaultPrompts.enhance : '' - return typeof enhance === 'string' ? enhance : defaultEnhance + const defaultEnhance = typeof defaultPrompts.enhance === "string" ? defaultPrompts.enhance : "" + return typeof enhance === "string" ? enhance : defaultEnhance } const handleTestEnhancement = () => { if (!testPrompt.trim()) return - + setIsEnhancing(true) vscode.postMessage({ type: "enhancePrompt", - text: testPrompt + text: testPrompt, }) } @@ -147,19 +145,23 @@ const PromptsView = ({ onDone }: PromptsViewProps) => {
-
+
Custom Instructions for All Modes
-
- These instructions apply to all modes. They provide a base set of behaviors that can be enhanced by mode-specific instructions below. +
+ These instructions apply to all modes. They provide a base set of behaviors that can be enhanced + by mode-specific instructions below.
{ - const value = (e as CustomEvent)?.detail?.target?.value || ((e as any).target as HTMLTextAreaElement).value + const value = + (e as CustomEvent)?.detail?.target?.value || + ((e as any).target as HTMLTextAreaElement).value setCustomInstructions(value || undefined) vscode.postMessage({ type: "customInstructions", - text: value.trim() || undefined + text: value.trim() || undefined, }) }} rows={4} @@ -168,32 +170,38 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { data-testid="global-custom-instructions-textarea" />
- Instructions can also be loaded from vscode.postMessage({ - type: "openFile", - text: "./.clinerules", - values: { - create: true, - content: "", - } - })} - >.clinerules in your workspace. + onClick={() => + vscode.postMessage({ + type: "openFile", + text: "./.clinerules", + values: { + create: true, + content: "", + }, + }) + }> + .clinerules + {" "} + in your workspace.

Mode-Specific Prompts

-
+
{AGENT_MODES.map((tab) => ( ))}
-
-
+
+
-
+
Role Definition
handleAgentReset(activeTab)} data-testid="reset-prompt-button" - title="Revert to default" - > + title="Revert to default">
-
- Define Cline's expertise and personality for this mode. This description shapes how Cline presents itself and approaches tasks. +
+ Define Cline's expertise and personality for this mode. This description shapes how + Cline presents itself and approaches tasks.
{ data-testid={`${activeTab}-prompt-textarea`} />
-
+
Mode-specific Custom Instructions
-
- Add behavioral guidelines specific to {activeTab} mode. These instructions enhance the base behaviors defined above. +
+ Add behavioral guidelines specific to {activeTab} mode. These instructions enhance the base + behaviors defined above.
{ const prompt = customPrompts?.[activeTab] - return typeof prompt === 'object' ? prompt.customInstructions ?? '' : '' + return typeof prompt === "object" ? (prompt.customInstructions ?? "") : "" })()} onChange={(e) => { - const value = (e as CustomEvent)?.detail?.target?.value || ((e as any).target as HTMLTextAreaElement).value + const value = + (e as CustomEvent)?.detail?.target?.value || + ((e as any).target as HTMLTextAreaElement).value const existingPrompt = customPrompts?.[activeTab] updateAgentPrompt(activeTab, { - ...(typeof existingPrompt === 'object' ? existingPrompt : {}), - customInstructions: value.trim() || undefined + ...(typeof existingPrompt === "object" ? existingPrompt : {}), + customInstructions: value.trim() || undefined, }) }} rows={4} @@ -271,25 +295,34 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { style={{ width: "100%" }} data-testid={`${activeTab}-custom-instructions-textarea`} /> -
- Custom instructions specific to {activeTab} mode can also be loaded from + Custom instructions specific to {activeTab} mode can also be loaded from{" "} + { // First create/update the file with current custom instructions const defaultContent = `# ${activeTab} Mode Rules\n\nAdd mode-specific rules and guidelines here.` const existingPrompt = customPrompts?.[activeTab] - const existingInstructions = typeof existingPrompt === 'object' ? existingPrompt.customInstructions : undefined + const existingInstructions = + typeof existingPrompt === "object" + ? existingPrompt.customInstructions + : undefined vscode.postMessage({ type: "updatePrompt", promptMode: activeTab, customPrompt: { - ...(typeof existingPrompt === 'object' ? existingPrompt : {}), - customInstructions: existingInstructions || defaultContent - } + ...(typeof existingPrompt === "object" ? existingPrompt : {}), + customInstructions: existingInstructions || defaultContent, + }, }) // Then open the file vscode.postMessage({ @@ -298,37 +331,40 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { values: { create: true, content: "", - } + }, }) - }} - >.clinerules-{activeTab} in your workspace. + }}> + .clinerules-{activeTab} + {" "} + in your workspace.
-
+
{ vscode.postMessage({ type: "getSystemPrompt", - mode: activeTab + mode: activeTab, }) }} - data-testid="preview-prompt-button" - > + data-testid="preview-prompt-button"> Preview System Prompt

Prompt Enhancement

-
- Use prompt enhancement to get tailored suggestions or improvements for your inputs. This ensures Cline understands your intent and provides the best possible responses. +
+ Use prompt enhancement to get tailored suggestions or improvements for your inputs. This ensures + Cline understands your intent and provides the best possible responses.
@@ -337,22 +373,22 @@ const PromptsView = ({ onDone }: PromptsViewProps) => {
API Configuration
- You can select an API configuration to always use for enhancing prompts, or just use whatever is currently selected + You can select an API configuration to always use for enhancing prompts, or just use + whatever is currently selected
{ const value = e.detail?.target?.value || e.target?.value setEnhancementApiConfigId(value) vscode.postMessage({ type: "enhancementApiConfigId", - text: value + text: value, }) }} - style={{ width: "300px" }} - > + style={{ width: "300px" }}> Use currently selected API configuration {(listApiConfigMeta || []).map((config) => ( @@ -363,15 +399,29 @@ const PromptsView = ({ onDone }: PromptsViewProps) => {
-
+
Enhancement Prompt
- +
-
+
This prompt will be used to refine your input when you hit the sparkle icon in chat.
@@ -382,7 +432,7 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { resize="vertical" style={{ width: "100%" }} /> - +
{ style={{ width: "100%" }} data-testid="test-prompt-textarea" /> -
+
+ appearance="primary"> Preview Prompt Enhancement
@@ -417,66 +467,68 @@ const PromptsView = ({ onDone }: PromptsViewProps) => {
{isDialogOpen && ( -
-
-
+
setIsDialogOpen(false)} style={{ - position: 'absolute', - top: '20px', - right: '20px' - }} - > + position: "absolute", + top: "20px", + right: "20px", + }}> -

{selectedPromptTitle}

-
+							

{selectedPromptTitle}

+
 								{selectedPromptContent}
 							
-
- setIsDialogOpen(false)}> - Close - +
+ setIsDialogOpen(false)}>Close
@@ -485,4 +537,4 @@ const PromptsView = ({ onDone }: PromptsViewProps) => { ) } -export default PromptsView \ No newline at end of file +export default PromptsView diff --git a/webview-ui/src/components/prompts/__tests__/PromptsView.test.tsx b/webview-ui/src/components/prompts/__tests__/PromptsView.test.tsx index 2ccc7f1..fb2d323 100644 --- a/webview-ui/src/components/prompts/__tests__/PromptsView.test.tsx +++ b/webview-ui/src/components/prompts/__tests__/PromptsView.test.tsx @@ -1,160 +1,166 @@ -import { render, screen, fireEvent } from '@testing-library/react' -import '@testing-library/jest-dom' -import PromptsView from '../PromptsView' -import { ExtensionStateContext } from '../../../context/ExtensionStateContext' -import { vscode } from '../../../utils/vscode' +import { render, screen, fireEvent } from "@testing-library/react" +import "@testing-library/jest-dom" +import PromptsView from "../PromptsView" +import { ExtensionStateContext } from "../../../context/ExtensionStateContext" +import { vscode } from "../../../utils/vscode" // Mock vscode API -jest.mock('../../../utils/vscode', () => ({ - vscode: { - postMessage: jest.fn() - } +jest.mock("../../../utils/vscode", () => ({ + vscode: { + postMessage: jest.fn(), + }, })) const mockExtensionState = { - customPrompts: {}, - listApiConfigMeta: [ - { id: 'config1', name: 'Config 1' }, - { id: 'config2', name: 'Config 2' } - ], - enhancementApiConfigId: '', - setEnhancementApiConfigId: jest.fn(), - mode: 'code', - customInstructions: 'Initial instructions', - setCustomInstructions: jest.fn() + customPrompts: {}, + listApiConfigMeta: [ + { id: "config1", name: "Config 1" }, + { id: "config2", name: "Config 2" }, + ], + enhancementApiConfigId: "", + setEnhancementApiConfigId: jest.fn(), + mode: "code", + customInstructions: "Initial instructions", + setCustomInstructions: jest.fn(), } const renderPromptsView = (props = {}) => { - const mockOnDone = jest.fn() - return render( - - - - ) + const mockOnDone = jest.fn() + return render( + + + , + ) } -describe('PromptsView', () => { - beforeEach(() => { - jest.clearAllMocks() - }) +describe("PromptsView", () => { + beforeEach(() => { + jest.clearAllMocks() + }) - it('renders all mode tabs', () => { - renderPromptsView() - expect(screen.getByTestId('code-tab')).toBeInTheDocument() - expect(screen.getByTestId('ask-tab')).toBeInTheDocument() - expect(screen.getByTestId('architect-tab')).toBeInTheDocument() - }) + it("renders all mode tabs", () => { + renderPromptsView() + expect(screen.getByTestId("code-tab")).toBeInTheDocument() + expect(screen.getByTestId("ask-tab")).toBeInTheDocument() + expect(screen.getByTestId("architect-tab")).toBeInTheDocument() + }) - it('defaults to current mode as active tab', () => { - renderPromptsView({ mode: 'ask' }) - - const codeTab = screen.getByTestId('code-tab') - const askTab = screen.getByTestId('ask-tab') - const architectTab = screen.getByTestId('architect-tab') - - expect(askTab).toHaveAttribute('data-active', 'true') - expect(codeTab).toHaveAttribute('data-active', 'false') - expect(architectTab).toHaveAttribute('data-active', 'false') - }) + it("defaults to current mode as active tab", () => { + renderPromptsView({ mode: "ask" }) - it('switches between tabs correctly', () => { - renderPromptsView({ mode: 'code' }) - - const codeTab = screen.getByTestId('code-tab') - const askTab = screen.getByTestId('ask-tab') - const architectTab = screen.getByTestId('architect-tab') - - // Initial state matches current mode (code) - expect(codeTab).toHaveAttribute('data-active', 'true') - expect(askTab).toHaveAttribute('data-active', 'false') - expect(architectTab).toHaveAttribute('data-active', 'false') - expect(architectTab).toHaveAttribute('data-active', 'false') - - // Click Ask tab - fireEvent.click(askTab) - expect(askTab).toHaveAttribute('data-active', 'true') - expect(codeTab).toHaveAttribute('data-active', 'false') - expect(architectTab).toHaveAttribute('data-active', 'false') - - // Click Architect tab - fireEvent.click(architectTab) - expect(architectTab).toHaveAttribute('data-active', 'true') - expect(askTab).toHaveAttribute('data-active', 'false') - expect(codeTab).toHaveAttribute('data-active', 'false') - }) + const codeTab = screen.getByTestId("code-tab") + const askTab = screen.getByTestId("ask-tab") + const architectTab = screen.getByTestId("architect-tab") - it('handles prompt changes correctly', () => { - renderPromptsView() - - const textarea = screen.getByTestId('code-prompt-textarea') - fireEvent(textarea, new CustomEvent('change', { - detail: { - target: { - value: 'New prompt value' - } - } - })) - - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'updatePrompt', - promptMode: 'code', - customPrompt: { roleDefinition: 'New prompt value' } - }) - }) + expect(askTab).toHaveAttribute("data-active", "true") + expect(codeTab).toHaveAttribute("data-active", "false") + expect(architectTab).toHaveAttribute("data-active", "false") + }) - it('resets prompt to default value', () => { - renderPromptsView() - - const resetButton = screen.getByTestId('reset-prompt-button') - fireEvent.click(resetButton) - - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'updatePrompt', - promptMode: 'code', - customPrompt: { roleDefinition: undefined } - }) - }) + it("switches between tabs correctly", () => { + renderPromptsView({ mode: "code" }) - it('handles API configuration selection', () => { - renderPromptsView() - - const dropdown = screen.getByTestId('api-config-dropdown') - fireEvent(dropdown, new CustomEvent('change', { - detail: { - target: { - value: 'config1' - } - } - })) - - expect(mockExtensionState.setEnhancementApiConfigId).toHaveBeenCalledWith('config1') - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'enhancementApiConfigId', - text: 'config1' - }) - }) + const codeTab = screen.getByTestId("code-tab") + const askTab = screen.getByTestId("ask-tab") + const architectTab = screen.getByTestId("architect-tab") - it('handles clearing custom instructions correctly', async () => { - const setCustomInstructions = jest.fn() - renderPromptsView({ - ...mockExtensionState, - customInstructions: 'Initial instructions', - setCustomInstructions - }) + // Initial state matches current mode (code) + expect(codeTab).toHaveAttribute("data-active", "true") + expect(askTab).toHaveAttribute("data-active", "false") + expect(architectTab).toHaveAttribute("data-active", "false") + expect(architectTab).toHaveAttribute("data-active", "false") - const textarea = screen.getByTestId('global-custom-instructions-textarea') - const changeEvent = new CustomEvent('change', { - detail: { target: { value: '' } } - }) - Object.defineProperty(changeEvent, 'target', { - value: { value: '' } - }) - await fireEvent(textarea, changeEvent) + // Click Ask tab + fireEvent.click(askTab) + expect(askTab).toHaveAttribute("data-active", "true") + expect(codeTab).toHaveAttribute("data-active", "false") + expect(architectTab).toHaveAttribute("data-active", "false") - expect(setCustomInstructions).toHaveBeenCalledWith(undefined) - expect(vscode.postMessage).toHaveBeenCalledWith({ - type: 'customInstructions', - text: undefined - }) - }) -}) \ No newline at end of file + // Click Architect tab + fireEvent.click(architectTab) + expect(architectTab).toHaveAttribute("data-active", "true") + expect(askTab).toHaveAttribute("data-active", "false") + expect(codeTab).toHaveAttribute("data-active", "false") + }) + + it("handles prompt changes correctly", () => { + renderPromptsView() + + const textarea = screen.getByTestId("code-prompt-textarea") + fireEvent( + textarea, + new CustomEvent("change", { + detail: { + target: { + value: "New prompt value", + }, + }, + }), + ) + + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "updatePrompt", + promptMode: "code", + customPrompt: { roleDefinition: "New prompt value" }, + }) + }) + + it("resets prompt to default value", () => { + renderPromptsView() + + const resetButton = screen.getByTestId("reset-prompt-button") + fireEvent.click(resetButton) + + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "updatePrompt", + promptMode: "code", + customPrompt: { roleDefinition: undefined }, + }) + }) + + it("handles API configuration selection", () => { + renderPromptsView() + + const dropdown = screen.getByTestId("api-config-dropdown") + fireEvent( + dropdown, + new CustomEvent("change", { + detail: { + target: { + value: "config1", + }, + }, + }), + ) + + expect(mockExtensionState.setEnhancementApiConfigId).toHaveBeenCalledWith("config1") + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "enhancementApiConfigId", + text: "config1", + }) + }) + + it("handles clearing custom instructions correctly", async () => { + const setCustomInstructions = jest.fn() + renderPromptsView({ + ...mockExtensionState, + customInstructions: "Initial instructions", + setCustomInstructions, + }) + + const textarea = screen.getByTestId("global-custom-instructions-textarea") + const changeEvent = new CustomEvent("change", { + detail: { target: { value: "" } }, + }) + Object.defineProperty(changeEvent, "target", { + value: { value: "" }, + }) + await fireEvent(textarea, changeEvent) + + expect(setCustomInstructions).toHaveBeenCalledWith(undefined) + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "customInstructions", + text: undefined, + }) + }) +}) diff --git a/webview-ui/src/components/settings/ApiConfigManager.tsx b/webview-ui/src/components/settings/ApiConfigManager.tsx index e05af14..0a60f60 100644 --- a/webview-ui/src/components/settings/ApiConfigManager.tsx +++ b/webview-ui/src/components/settings/ApiConfigManager.tsx @@ -3,223 +3,216 @@ import { memo, useEffect, useRef, useState } from "react" import { ApiConfigMeta } from "../../../../src/shared/ExtensionMessage" interface ApiConfigManagerProps { - currentApiConfigName?: string - listApiConfigMeta?: ApiConfigMeta[] - onSelectConfig: (configName: string) => void - onDeleteConfig: (configName: string) => void - onRenameConfig: (oldName: string, newName: string) => void - onUpsertConfig: (configName: string) => void + currentApiConfigName?: string + listApiConfigMeta?: ApiConfigMeta[] + onSelectConfig: (configName: string) => void + onDeleteConfig: (configName: string) => void + onRenameConfig: (oldName: string, newName: string) => void + onUpsertConfig: (configName: string) => void } const ApiConfigManager = ({ - currentApiConfigName = "", - listApiConfigMeta = [], - onSelectConfig, - onDeleteConfig, - onRenameConfig, - onUpsertConfig, + currentApiConfigName = "", + listApiConfigMeta = [], + onSelectConfig, + onDeleteConfig, + onRenameConfig, + onUpsertConfig, }: ApiConfigManagerProps) => { - const [editState, setEditState] = useState<'new' | 'rename' | null>(null); - const [inputValue, setInputValue] = useState(""); - const inputRef = useRef(); + const [editState, setEditState] = useState<"new" | "rename" | null>(null) + const [inputValue, setInputValue] = useState("") + const inputRef = useRef() - // Focus input when entering edit mode - useEffect(() => { - if (editState) { - setTimeout(() => inputRef.current?.focus(), 0); - } - }, [editState]); + // Focus input when entering edit mode + useEffect(() => { + if (editState) { + setTimeout(() => inputRef.current?.focus(), 0) + } + }, [editState]) - // Reset edit state when current profile changes - useEffect(() => { - setEditState(null); - setInputValue(""); - }, [currentApiConfigName]); + // Reset edit state when current profile changes + useEffect(() => { + setEditState(null) + setInputValue("") + }, [currentApiConfigName]) - const handleAdd = () => { - const newConfigName = currentApiConfigName + " (copy)"; - onUpsertConfig(newConfigName); - }; + const handleAdd = () => { + const newConfigName = currentApiConfigName + " (copy)" + onUpsertConfig(newConfigName) + } - const handleStartRename = () => { - setEditState('rename'); - setInputValue(currentApiConfigName || ""); - }; + const handleStartRename = () => { + setEditState("rename") + setInputValue(currentApiConfigName || "") + } - const handleCancel = () => { - setEditState(null); - setInputValue(""); - }; + const handleCancel = () => { + setEditState(null) + setInputValue("") + } - const handleSave = () => { - const trimmedValue = inputValue.trim(); - if (!trimmedValue) return; + const handleSave = () => { + const trimmedValue = inputValue.trim() + if (!trimmedValue) return - if (editState === 'new') { - onUpsertConfig(trimmedValue); - } else if (editState === 'rename' && currentApiConfigName) { - onRenameConfig(currentApiConfigName, trimmedValue); - } + if (editState === "new") { + onUpsertConfig(trimmedValue) + } else if (editState === "rename" && currentApiConfigName) { + onRenameConfig(currentApiConfigName, trimmedValue) + } - setEditState(null); - setInputValue(""); - }; + setEditState(null) + setInputValue("") + } - const handleDelete = () => { - if (!currentApiConfigName || !listApiConfigMeta || listApiConfigMeta.length <= 1) return; - - // Let the extension handle both deletion and selection - onDeleteConfig(currentApiConfigName); - }; + const handleDelete = () => { + if (!currentApiConfigName || !listApiConfigMeta || listApiConfigMeta.length <= 1) return - const isOnlyProfile = listApiConfigMeta?.length === 1; + // Let the extension handle both deletion and selection + onDeleteConfig(currentApiConfigName) + } - return ( -
-
- + const isOnlyProfile = listApiConfigMeta?.length === 1 - {editState ? ( -
- setInputValue(e.target.value)} - placeholder={editState === 'new' ? "Enter profile name" : "Enter new name"} - style={{ flexGrow: 1 }} - onKeyDown={(e: any) => { - if (e.key === 'Enter' && inputValue.trim()) { - handleSave(); - } else if (e.key === 'Escape') { - handleCancel(); - } - }} - /> - - - - - - -
- ) : ( - <> -
- - - - - {currentApiConfigName && ( - <> - - - - - - - - )} -
-

- Save different API configurations to quickly switch between providers and settings -

- - )} -
-
- ) + return ( +
+
+ + + {editState ? ( +
+ setInputValue(e.target.value)} + placeholder={editState === "new" ? "Enter profile name" : "Enter new name"} + style={{ flexGrow: 1 }} + onKeyDown={(e: any) => { + if (e.key === "Enter" && inputValue.trim()) { + handleSave() + } else if (e.key === "Escape") { + handleCancel() + } + }} + /> + + + + + + +
+ ) : ( + <> +
+ + + + + {currentApiConfigName && ( + <> + + + + + + + + )} +
+

+ Save different API configurations to quickly switch between providers and settings +

+ + )} +
+
+ ) } -export default memo(ApiConfigManager) \ No newline at end of file +export default memo(ApiConfigManager) diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 13edd85..8e6fe42 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -1,11 +1,6 @@ import { Checkbox, Dropdown } from "vscrui" import type { DropdownOption } from "vscrui" -import { - VSCodeLink, - VSCodeRadio, - VSCodeRadioGroup, - VSCodeTextField -} from "@vscode/webview-ui-toolkit/react" +import { VSCodeLink, VSCodeRadio, VSCodeRadioGroup, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" import { Fragment, memo, useCallback, useEffect, useMemo, useState } from "react" import { useEvent, useInterval } from "react-use" import { @@ -83,7 +78,12 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = requestLocalModels() } }, [selectedProvider, requestLocalModels]) - useInterval(requestLocalModels, selectedProvider === "ollama" || selectedProvider === "lmstudio" || selectedProvider === "vscode-lm" ? 2000 : null) + useInterval( + requestLocalModels, + selectedProvider === "ollama" || selectedProvider === "lmstudio" || selectedProvider === "vscode-lm" + ? 2000 + : null, + ) const handleMessage = useCallback((event: MessageEvent) => { const message: ExtensionMessage = event.data if (message.type === "ollamaModels" && message.ollamaModels) { @@ -102,17 +102,19 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = ...Object.keys(models).map((modelId) => ({ value: modelId, label: modelId, - })) + })), ] return ( {handleInputChange("apiModelId")({ - target: { - value: (value as DropdownOption).value - } - })}} + onChange={(value: unknown) => { + handleInputChange("apiModelId")({ + target: { + value: (value as DropdownOption).value, + }, + }) + }} style={{ width: "100%" }} options={options} /> @@ -131,8 +133,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = onChange={(value: unknown) => { handleInputChange("apiProvider")({ target: { - value: (value as DropdownOption).value - } + value: (value as DropdownOption).value, + }, }) }} style={{ minWidth: 130, position: "relative", zIndex: OPENROUTER_MODEL_PICKER_Z_INDEX + 1 }} @@ -149,7 +151,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = { value: "vscode-lm", label: "VS Code LM API" }, { value: "mistral", label: "Mistral" }, { value: "lmstudio", label: "LM Studio" }, - { value: "ollama", label: "Ollama" } + { value: "ollama", label: "Ollama" }, ]} />
@@ -331,7 +333,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = target: { value: checked }, }) }}> - Compress prompts and message chains to the context size (OpenRouter Transforms) + Compress prompts and message chains to the context size ( + OpenRouter Transforms)
@@ -371,11 +374,13 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = id="aws-region-dropdown" value={apiConfiguration?.awsRegion || ""} style={{ width: "100%" }} - onChange={(value: unknown) => {handleInputChange("awsRegion")({ - target: { - value: (value as DropdownOption).value - } - })}} + onChange={(value: unknown) => { + handleInputChange("awsRegion")({ + target: { + value: (value as DropdownOption).value, + }, + }) + }} options={[ { value: "", label: "Select a region..." }, { value: "us-east-1", label: "us-east-1" }, @@ -392,7 +397,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = { value: "eu-west-2", label: "eu-west-2" }, { value: "eu-west-3", label: "eu-west-3" }, { value: "sa-east-1", label: "sa-east-1" }, - { value: "us-gov-west-1", label: "us-gov-west-1" } + { value: "us-gov-west-1", label: "us-gov-west-1" }, ]} />
@@ -435,18 +440,20 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = id="vertex-region-dropdown" value={apiConfiguration?.vertexRegion || ""} style={{ width: "100%" }} - onChange={(value: unknown) => {handleInputChange("vertexRegion")({ - target: { - value: (value as DropdownOption).value - } - })}} + onChange={(value: unknown) => { + handleInputChange("vertexRegion")({ + target: { + value: (value as DropdownOption).value, + }, + }) + }} options={[ { value: "", label: "Select a region..." }, { value: "us-east5", label: "us-east5" }, { value: "us-central1", label: "us-central1" }, { value: "europe-west1", label: "europe-west1" }, { value: "europe-west4", label: "europe-west4" }, - { value: "asia-southeast1", label: "asia-southeast1" } + { value: "asia-southeast1", label: "asia-southeast1" }, ]} />
@@ -520,7 +527,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = API Key -
+
{ @@ -669,19 +676,21 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = {vsCodeLmModels.length > 0 ? ( { - const valueStr = (value as DropdownOption)?.value; + const valueStr = (value as DropdownOption)?.value if (!valueStr) { return } - const [vendor, family] = valueStr.split('/'); + const [vendor, family] = valueStr.split("/") handleInputChange("vsCodeLmModelSelector")({ target: { - value: { vendor, family } - } + value: { vendor, family }, + }, }) }} style={{ width: "100%" }} @@ -689,18 +698,20 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = { value: "", label: "Select a model..." }, ...vsCodeLmModels.map((model) => ({ value: `${model.vendor}/${model.family}`, - label: `${model.vendor} - ${model.family}` - })) + label: `${model.vendor} - ${model.family}`, + })), ]} /> ) : ( -

- The VS Code Language Model API allows you to run models provided by other VS Code extensions (including but not limited to GitHub Copilot). - The easiest way to get started is to install the Copilot and Copilot Chat extensions from the VS Code Marketplace. +

+ The VS Code Language Model API allows you to run models provided by other VS Code + extensions (including but not limited to GitHub Copilot). The easiest way to get started + is to install the Copilot and Copilot Chat extensions from the VS Code Marketplace.

)} @@ -711,7 +722,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = color: "var(--vscode-errorForeground)", fontWeight: 500, }}> - Note: This is a very experimental integration and may not work as expected. Please report any issues to the Roo-Cline GitHub repository. + Note: This is a very experimental integration and may not work as expected. Please report + any issues to the Roo-Cline GitHub repository.

@@ -1042,9 +1054,9 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) { case "vscode-lm": return { selectedProvider: provider, - selectedModelId: apiConfiguration?.vsCodeLmModelSelector ? - `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}` : - "", + selectedModelId: apiConfiguration?.vsCodeLmModelSelector + ? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}` + : "", selectedModelInfo: { ...openAiModelInfoSaneDefaults, supportsImages: false, // VSCode LM API currently doesn't support images diff --git a/webview-ui/src/components/settings/GlamaModelPicker.tsx b/webview-ui/src/components/settings/GlamaModelPicker.tsx index 32f27f5..6d1dee1 100644 --- a/webview-ui/src/components/settings/GlamaModelPicker.tsx +++ b/webview-ui/src/components/settings/GlamaModelPicker.tsx @@ -38,7 +38,6 @@ const GlamaModelPicker: React.FC = () => { return normalizeApiConfiguration(apiConfiguration) }, [apiConfiguration]) - useEffect(() => { if (apiConfiguration?.glamaModelId && apiConfiguration?.glamaModelId !== searchTerm) { setSearchTerm(apiConfiguration?.glamaModelId) @@ -47,18 +46,15 @@ const GlamaModelPicker: React.FC = () => { const debouncedRefreshModels = useMemo( () => - debounce( - () => { - vscode.postMessage({ type: "refreshGlamaModels" }) - }, - 50 - ), - [] + debounce(() => { + vscode.postMessage({ type: "refreshGlamaModels" }) + }, 50), + [], ) useMount(() => { debouncedRefreshModels() - + // Cleanup debounced function return () => { debouncedRefreshModels.clear() @@ -91,7 +87,7 @@ const GlamaModelPicker: React.FC = () => { const fzf = useMemo(() => { return new Fzf(searchableItems, { - selector: item => item.html + selector: (item) => item.html, }) }, [searchableItems]) @@ -99,9 +95,9 @@ const GlamaModelPicker: React.FC = () => { if (!searchTerm) return searchableItems const searchResults = fzf.find(searchTerm) - return searchResults.map(result => ({ + return searchResults.map((result) => ({ ...result.item, - html: highlightFzfMatch(result.item.html, Array.from(result.positions), "model-item-highlight") + html: highlightFzfMatch(result.item.html, Array.from(result.positions), "model-item-highlight"), })) }, [searchableItems, searchTerm, fzf]) diff --git a/webview-ui/src/components/settings/OpenAiModelPicker.tsx b/webview-ui/src/components/settings/OpenAiModelPicker.tsx index 33166ba..71957e7 100644 --- a/webview-ui/src/components/settings/OpenAiModelPicker.tsx +++ b/webview-ui/src/components/settings/OpenAiModelPicker.tsx @@ -37,19 +37,16 @@ const OpenAiModelPicker: React.FC = () => { const debouncedRefreshModels = useMemo( () => - debounce( - (baseUrl: string, apiKey: string) => { - vscode.postMessage({ - type: "refreshOpenAiModels", - values: { - baseUrl, - apiKey - } - }) - }, - 50 - ), - [] + debounce((baseUrl: string, apiKey: string) => { + vscode.postMessage({ + type: "refreshOpenAiModels", + values: { + baseUrl, + apiKey, + }, + }) + }, 50), + [], ) useEffect(() => { @@ -57,10 +54,7 @@ const OpenAiModelPicker: React.FC = () => { return } - debouncedRefreshModels( - apiConfiguration.openAiBaseUrl, - apiConfiguration.openAiApiKey - ) + debouncedRefreshModels(apiConfiguration.openAiBaseUrl, apiConfiguration.openAiApiKey) // Cleanup debounced function return () => { @@ -94,7 +88,7 @@ const OpenAiModelPicker: React.FC = () => { const fzf = useMemo(() => { return new Fzf(searchableItems, { - selector: item => item.html + selector: (item) => item.html, }) }, [searchableItems]) @@ -102,9 +96,9 @@ const OpenAiModelPicker: React.FC = () => { if (!searchTerm) return searchableItems const searchResults = fzf.find(searchTerm) - return searchResults.map(result => ({ + return searchResults.map((result) => ({ ...result.item, - html: highlightFzfMatch(result.item.html, Array.from(result.positions), "model-item-highlight") + html: highlightFzfMatch(result.item.html, Array.from(result.positions), "model-item-highlight"), })) }, [searchableItems, searchTerm, fzf]) diff --git a/webview-ui/src/components/settings/OpenRouterModelPicker.tsx b/webview-ui/src/components/settings/OpenRouterModelPicker.tsx index 568d99d..a9508b8 100644 --- a/webview-ui/src/components/settings/OpenRouterModelPicker.tsx +++ b/webview-ui/src/components/settings/OpenRouterModelPicker.tsx @@ -46,18 +46,15 @@ const OpenRouterModelPicker: React.FC = () => { const debouncedRefreshModels = useMemo( () => - debounce( - () => { - vscode.postMessage({ type: "refreshOpenRouterModels" }) - }, - 50 - ), - [] + debounce(() => { + vscode.postMessage({ type: "refreshOpenRouterModels" }) + }, 50), + [], ) useMount(() => { debouncedRefreshModels() - + // Cleanup debounced function return () => { debouncedRefreshModels.clear() @@ -90,7 +87,7 @@ const OpenRouterModelPicker: React.FC = () => { const fzf = useMemo(() => { return new Fzf(searchableItems, { - selector: item => item.html + selector: (item) => item.html, }) }, [searchableItems]) @@ -98,9 +95,9 @@ const OpenRouterModelPicker: React.FC = () => { if (!searchTerm) return searchableItems const searchResults = fzf.find(searchTerm) - return searchResults.map(result => ({ + return searchResults.map((result) => ({ ...result.item, - html: highlightFzfMatch(result.item.html, Array.from(result.positions), "model-item-highlight") + html: highlightFzfMatch(result.item.html, Array.from(result.positions), "model-item-highlight"), })) }, [searchableItems, searchTerm, fzf]) diff --git a/webview-ui/src/components/settings/SettingsView.tsx b/webview-ui/src/components/settings/SettingsView.tsx index beafc74..870fda8 100644 --- a/webview-ui/src/components/settings/SettingsView.tsx +++ b/webview-ui/src/components/settings/SettingsView.tsx @@ -1,4 +1,10 @@ -import { VSCodeButton, VSCodeCheckbox, VSCodeLink, VSCodeTextArea, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" +import { + VSCodeButton, + VSCodeCheckbox, + VSCodeLink, + VSCodeTextArea, + VSCodeTextField, +} from "@vscode/webview-ui-toolkit/react" import { memo, useEffect, useState } from "react" import { useExtensionState } from "../../context/ExtensionStateContext" import { validateApiConfiguration, validateModelId } from "../../utils/validate" @@ -61,7 +67,7 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { listApiConfigMeta, mode, setMode, - experimentalDiffStrategy, + experimentalDiffStrategy, setExperimentalDiffStrategy, } = useExtensionState() const [apiErrorMessage, setApiErrorMessage] = useState(undefined) @@ -77,7 +83,7 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { if (!apiValidationResult && !modelIdValidationResult) { vscode.postMessage({ type: "apiConfiguration", - apiConfiguration + apiConfiguration, }) vscode.postMessage({ type: "customInstructions", text: customInstructions }) vscode.postMessage({ type: "alwaysAllowReadOnly", bool: alwaysAllowReadOnly }) @@ -102,10 +108,10 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { vscode.postMessage({ type: "upsertApiConfiguration", text: currentApiConfigName, - apiConfiguration + apiConfiguration, }) vscode.postMessage({ type: "mode", text: mode }) - vscode.postMessage({ type: "experimentalDiffStrategy", bool: experimentalDiffStrategy }) + vscode.postMessage({ type: "experimentalDiffStrategy", bool: experimentalDiffStrategy }) onDone() } } @@ -135,7 +141,7 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { setCommandInput("") vscode.postMessage({ type: "allowedCommands", - commands: newCommands + commands: newCommands, }) } } @@ -161,53 +167,53 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { marginBottom: "17px", paddingRight: 17, }}> -

Settings

Done
-

Provider Settings

+

+ Provider Settings +

{ vscode.postMessage({ type: "loadApiConfiguration", - text: configName + text: configName, }) }} onDeleteConfig={(configName: string) => { vscode.postMessage({ type: "deleteApiConfiguration", - text: configName + text: configName, }) }} onRenameConfig={(oldName: string, newName: string) => { vscode.postMessage({ type: "renameApiConfiguration", values: { oldName, newName }, - apiConfiguration + apiConfiguration, }) }} onUpsertConfig={(configName: string) => { vscode.postMessage({ type: "upsertApiConfiguration", text: configName, - apiConfiguration + apiConfiguration, }) }} /> - +
-

Agent Settings

+

+ Agent Settings +

@@ -225,22 +231,27 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { color: "var(--vscode-input-foreground)", border: "1px solid var(--vscode-input-border)", borderRadius: "2px", - height: "28px" + height: "28px", }}> -

- Select the mode that best fits your needs. Code mode focuses on implementation details, Architect mode on high-level design, and Ask mode on asking questions about the codebase. +

+ Select the mode that best fits your needs. Code mode focuses on implementation details, + Architect mode on high-level design, and Ask mode on asking questions about the + codebase.

- + -

+

Select the language that Cline should use for communication.

@@ -298,7 +310,11 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { marginTop: "5px", color: "var(--vscode-descriptionForeground)", }}> - These instructions are added to the end of the system prompt sent with every request. Custom instructions set in .clinerules in the working directory are also included. For mode-specific instructions, use the Prompts tab in the top menu. + These instructions are added to the end of the system prompt sent with every request. Custom + instructions set in .clinerules in the working directory are also included. For + mode-specific instructions, use the{" "} + Prompts tab + in the top menu.

@@ -306,8 +322,8 @@ const SettingsView = ({ onDone }: SettingsViewProps) => {
-
- Terminal output limit +
+ Terminal output limit { onChange={(e) => setTerminalOutputLineLimit(parseInt(e.target.value))} style={{ flexGrow: 1, - accentColor: 'var(--vscode-button-background)', - height: '2px' + accentColor: "var(--vscode-button-background)", + height: "2px", }} /> - - {terminalOutputLineLimit ?? 500} - + {terminalOutputLineLimit ?? 500}

- Maximum number of lines to include in terminal output when executing commands. When exceeded lines will be removed from the middle, saving tokens. + Maximum number of lines to include in terminal output when executing commands. When exceeded + lines will be removed from the middle, saving tokens.

- { - setDiffEnabled(e.target.checked) - if (!e.target.checked) { - // Reset experimental strategy when diffs are disabled - setExperimentalDiffStrategy(false) - } - }}> + { + setDiffEnabled(e.target.checked) + if (!e.target.checked) { + // Reset experimental strategy when diffs are disabled + setExperimentalDiffStrategy(false) + } + }}> Enable editing through diffs

{ marginTop: "5px", color: "var(--vscode-descriptionForeground)", }}> - When enabled, Cline will be able to edit files more quickly and will automatically reject truncated full-file writes. Works best with the latest Claude 3.5 Sonnet model. + When enabled, Cline will be able to edit files more quickly and will automatically reject + truncated full-file writes. Works best with the latest Claude 3.5 Sonnet model.

{diffEnabled && (
-
+
⚠️ { Use experimental unified diff strategy
-

- Enable the experimental unified diff strategy. This strategy might reduce the number of retries caused by model errors but may cause unexpected behavior or incorrect edits. +

+ Enable the experimental unified diff strategy. This strategy might reduce the number of + retries caused by model errors but may cause unexpected behavior or incorrect edits. Only enable if you understand the risks and are willing to carefully review all changes.

-
- Match precision +
+ Match precision { step="0.005" value={fuzzyMatchThreshold ?? 1.0} onChange={(e) => { - setFuzzyMatchThreshold(parseFloat(e.target.value)); + setFuzzyMatchThreshold(parseFloat(e.target.value)) }} style={{ flexGrow: 1, - accentColor: 'var(--vscode-button-background)', - height: '2px' + accentColor: "var(--vscode-button-background)", + height: "2px", }} /> - + {Math.round((fuzzyMatchThreshold || 1) * 100)}%
-

- This slider controls how precisely code sections must match when applying diffs. Lower values allow more flexible matching but increase the risk of incorrect replacements. Use values below 100% with extreme caution. +

+ This slider controls how precisely code sections must match when applying diffs. Lower + values allow more flexible matching but increase the risk of incorrect replacements. Use + values below 100% with extreme caution.

)} @@ -409,11 +440,20 @@ const SettingsView = ({ onDone }: SettingsViewProps) => {

-
-

⚠️ High-Risk Auto-Approve Settings

+
+

+ ⚠️ High-Risk Auto-Approve Settings +

- The following settings allow Cline to automatically perform potentially dangerous operations without requiring approval. - Enable these settings only if you fully trust the AI and understand the associated security risks. + The following settings allow Cline to automatically perform potentially dangerous operations + without requiring approval. Enable these settings only if you fully trust the AI and understand + the associated security risks.

@@ -427,7 +467,7 @@ const SettingsView = ({ onDone }: SettingsViewProps) => {

{alwaysAllowWrite && (
-
+
{ onChange={(e) => setWriteDelayMs(parseInt(e.target.value))} style={{ flex: 1, - accentColor: 'var(--vscode-button-background)', - height: '2px' + accentColor: "var(--vscode-button-background)", + height: "2px", }} /> - - {writeDelayMs}ms - + {writeDelayMs}ms
-

+

Delay after writes to allow diagnostics to detect potential problems

@@ -459,7 +502,8 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { Always approve browser actions

- Automatically perform browser actions without requiring approval
+ Automatically perform browser actions without requiring approval +
Note: Only applies when the model supports computer use

@@ -475,7 +519,7 @@ const SettingsView = ({ onDone }: SettingsViewProps) => {

{alwaysApproveResubmit && (
-
+
{ onChange={(e) => setRequestDelaySeconds(parseInt(e.target.value))} style={{ flex: 1, - accentColor: 'var(--vscode-button-background)', - height: '2px' + accentColor: "var(--vscode-button-background)", + height: "2px", }} /> - - {requestDelaySeconds}s - + {requestDelaySeconds}s
-

+

Delay before retrying the request

@@ -507,7 +554,8 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { Always approve MCP tools

- Enable auto-approval of individual MCP tools in the MCP Servers view (requires both this setting and the tool's individual "Always allow" checkbox) + Enable auto-approval of individual MCP tools in the MCP Servers view (requires both this + setting and the tool's individual "Always allow" checkbox)

@@ -524,20 +572,22 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { {alwaysAllowExecute && (
Allowed Auto-Execute Commands -

- Command prefixes that can be auto-executed when "Always approve execute operations" is enabled. +

+ Command prefixes that can be auto-executed when "Always approve execute operations" + is enabled.

-
+
setCommandInput(e.target.value)} onKeyDown={(e: any) => { - if (e.key === 'Enter') { + if (e.key === "Enter") { e.preventDefault() handleAddCommand() } @@ -545,51 +595,53 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { placeholder="Enter command prefix (e.g., 'git ')" style={{ flexGrow: 1 }} /> - - Add - + Add
-
+
{(allowedCommands ?? []).map((cmd, index) => ( -
+
{cmd} { - const newCommands = (allowedCommands ?? []).filter((_, i) => i !== index) + const newCommands = (allowedCommands ?? []).filter( + (_, i) => i !== index, + ) setAllowedCommands(newCommands) vscode.postMessage({ type: "allowedCommands", - commands: newCommands + commands: newCommands, }) - }} - > + }}>
@@ -603,8 +655,12 @@ const SettingsView = ({ onDone }: SettingsViewProps) => {
-

Browser Settings

- +

+ Browser Settings +

+ -

- Select the viewport size for browser interactions. This affects how websites are displayed and interacted with. +

+ Select the viewport size for browser interactions. This affects how websites are + displayed and interacted with.

-
- Screenshot quality +
+ Screenshot quality { onChange={(e) => setScreenshotQuality(parseInt(e.target.value))} style={{ flexGrow: 1, - accentColor: 'var(--vscode-button-background)', - height: '2px' + accentColor: "var(--vscode-button-background)", + height: "2px", }} /> - - {screenshotQuality ?? 75}% - + {screenshotQuality ?? 75}%
-

- Adjust the WebP quality of browser screenshots. Higher values provide clearer screenshots but increase token usage. +

+ Adjust the WebP quality of browser screenshots. Higher values provide clearer + screenshots but increase token usage.

-

Notification Settings

- setSoundEnabled(e.target.checked)}> +

+ Notification Settings +

+ setSoundEnabled(e.target.checked)}> Enable sound effects

{

{soundEnabled && (
-
- Volume +
+ Volume { onChange={(e) => setSoundVolume(parseFloat(e.target.value))} style={{ flexGrow: 1, - accentColor: 'var(--vscode-button-background)', - height: '2px' + accentColor: "var(--vscode-button-background)", + height: "2px", }} aria-label="Volume" /> - + {((soundVolume ?? 0.5) * 100).toFixed(0)}%
@@ -733,7 +795,8 @@ const SettingsView = ({ onDone }: SettingsViewProps) => { If you have any questions or feedback, feel free to open an issue at{" "} github.com/RooVetGit/Roo-Cline - or join {" "} + {" "} + or join{" "} reddit.com/r/roocline diff --git a/webview-ui/src/components/settings/__tests__/ApiConfigManager.test.tsx b/webview-ui/src/components/settings/__tests__/ApiConfigManager.test.tsx index 2f326f4..784d443 100644 --- a/webview-ui/src/components/settings/__tests__/ApiConfigManager.test.tsx +++ b/webview-ui/src/components/settings/__tests__/ApiConfigManager.test.tsx @@ -1,154 +1,136 @@ -import { render, screen, fireEvent } from '@testing-library/react'; -import '@testing-library/jest-dom'; -import ApiConfigManager from '../ApiConfigManager'; +import { render, screen, fireEvent } from "@testing-library/react" +import "@testing-library/jest-dom" +import ApiConfigManager from "../ApiConfigManager" // Mock VSCode components -jest.mock('@vscode/webview-ui-toolkit/react', () => ({ - VSCodeButton: ({ children, onClick, title, disabled }: any) => ( - - ), - VSCodeTextField: ({ value, onInput, placeholder }: any) => ( - onInput(e)} - placeholder={placeholder} - ref={undefined} // Explicitly set ref to undefined to avoid warning - /> - ), -})); +jest.mock("@vscode/webview-ui-toolkit/react", () => ({ + VSCodeButton: ({ children, onClick, title, disabled }: any) => ( + + ), + VSCodeTextField: ({ value, onInput, placeholder }: any) => ( + onInput(e)} + placeholder={placeholder} + ref={undefined} // Explicitly set ref to undefined to avoid warning + /> + ), +})) -describe('ApiConfigManager', () => { - const mockOnSelectConfig = jest.fn(); - const mockOnDeleteConfig = jest.fn(); - const mockOnRenameConfig = jest.fn(); - const mockOnUpsertConfig = jest.fn(); +describe("ApiConfigManager", () => { + const mockOnSelectConfig = jest.fn() + const mockOnDeleteConfig = jest.fn() + const mockOnRenameConfig = jest.fn() + const mockOnUpsertConfig = jest.fn() - const defaultProps = { - currentApiConfigName: 'Default Config', - listApiConfigMeta: [ - { name: 'Default Config' }, - { name: 'Another Config' } - ], - onSelectConfig: mockOnSelectConfig, - onDeleteConfig: mockOnDeleteConfig, - onRenameConfig: mockOnRenameConfig, - onUpsertConfig: mockOnUpsertConfig, - }; + const defaultProps = { + currentApiConfigName: "Default Config", + listApiConfigMeta: [{ name: "Default Config" }, { name: "Another Config" }], + onSelectConfig: mockOnSelectConfig, + onDeleteConfig: mockOnDeleteConfig, + onRenameConfig: mockOnRenameConfig, + onUpsertConfig: mockOnUpsertConfig, + } - beforeEach(() => { - jest.clearAllMocks(); - }); + beforeEach(() => { + jest.clearAllMocks() + }) - it('immediately creates a copy when clicking add button', () => { - render(); + it("immediately creates a copy when clicking add button", () => { + render() - // Find and click the add button - const addButton = screen.getByTitle('Add profile'); - fireEvent.click(addButton); + // Find and click the add button + const addButton = screen.getByTitle("Add profile") + fireEvent.click(addButton) - // Verify that onUpsertConfig was called with the correct name - expect(mockOnUpsertConfig).toHaveBeenCalledTimes(1); - expect(mockOnUpsertConfig).toHaveBeenCalledWith('Default Config (copy)'); - }); + // Verify that onUpsertConfig was called with the correct name + expect(mockOnUpsertConfig).toHaveBeenCalledTimes(1) + expect(mockOnUpsertConfig).toHaveBeenCalledWith("Default Config (copy)") + }) - it('creates copy with correct name when current config has spaces', () => { - render( - - ); + it("creates copy with correct name when current config has spaces", () => { + render() - const addButton = screen.getByTitle('Add profile'); - fireEvent.click(addButton); + const addButton = screen.getByTitle("Add profile") + fireEvent.click(addButton) - expect(mockOnUpsertConfig).toHaveBeenCalledWith('My Test Config (copy)'); - }); + expect(mockOnUpsertConfig).toHaveBeenCalledWith("My Test Config (copy)") + }) - it('handles empty current config name gracefully', () => { - render( - - ); + it("handles empty current config name gracefully", () => { + render() - const addButton = screen.getByTitle('Add profile'); - fireEvent.click(addButton); + const addButton = screen.getByTitle("Add profile") + fireEvent.click(addButton) - expect(mockOnUpsertConfig).toHaveBeenCalledWith(' (copy)'); - }); + expect(mockOnUpsertConfig).toHaveBeenCalledWith(" (copy)") + }) - it('allows renaming the current config', () => { - render(); - - // Start rename - const renameButton = screen.getByTitle('Rename profile'); - fireEvent.click(renameButton); + it("allows renaming the current config", () => { + render() - // Find input and enter new name - const input = screen.getByDisplayValue('Default Config'); - fireEvent.input(input, { target: { value: 'New Name' } }); + // Start rename + const renameButton = screen.getByTitle("Rename profile") + fireEvent.click(renameButton) - // Save - const saveButton = screen.getByTitle('Save'); - fireEvent.click(saveButton); + // Find input and enter new name + const input = screen.getByDisplayValue("Default Config") + fireEvent.input(input, { target: { value: "New Name" } }) - expect(mockOnRenameConfig).toHaveBeenCalledWith('Default Config', 'New Name'); - }); + // Save + const saveButton = screen.getByTitle("Save") + fireEvent.click(saveButton) - it('allows selecting a different config', () => { - render(); - - const select = screen.getByRole('combobox'); - fireEvent.change(select, { target: { value: 'Another Config' } }); + expect(mockOnRenameConfig).toHaveBeenCalledWith("Default Config", "New Name") + }) - expect(mockOnSelectConfig).toHaveBeenCalledWith('Another Config'); - }); + it("allows selecting a different config", () => { + render() - it('allows deleting the current config when not the only one', () => { - render(); - - const deleteButton = screen.getByTitle('Delete profile'); - expect(deleteButton).not.toBeDisabled(); - - fireEvent.click(deleteButton); - expect(mockOnDeleteConfig).toHaveBeenCalledWith('Default Config'); - }); + const select = screen.getByRole("combobox") + fireEvent.change(select, { target: { value: "Another Config" } }) - it('disables delete button when only one config exists', () => { - render( - - ); - - const deleteButton = screen.getByTitle('Cannot delete the only profile'); - expect(deleteButton).toHaveAttribute('disabled'); - }); + expect(mockOnSelectConfig).toHaveBeenCalledWith("Another Config") + }) - it('cancels rename operation when clicking cancel', () => { - render(); - - // Start rename - const renameButton = screen.getByTitle('Rename profile'); - fireEvent.click(renameButton); + it("allows deleting the current config when not the only one", () => { + render() - // Find input and enter new name - const input = screen.getByDisplayValue('Default Config'); - fireEvent.input(input, { target: { value: 'New Name' } }); + const deleteButton = screen.getByTitle("Delete profile") + expect(deleteButton).not.toBeDisabled() - // Cancel - const cancelButton = screen.getByTitle('Cancel'); - fireEvent.click(cancelButton); + fireEvent.click(deleteButton) + expect(mockOnDeleteConfig).toHaveBeenCalledWith("Default Config") + }) - // Verify rename was not called - expect(mockOnRenameConfig).not.toHaveBeenCalled(); - - // Verify we're back to normal view - expect(screen.queryByDisplayValue('New Name')).not.toBeInTheDocument(); - }); -}); \ No newline at end of file + it("disables delete button when only one config exists", () => { + render() + + const deleteButton = screen.getByTitle("Cannot delete the only profile") + expect(deleteButton).toHaveAttribute("disabled") + }) + + it("cancels rename operation when clicking cancel", () => { + render() + + // Start rename + const renameButton = screen.getByTitle("Rename profile") + fireEvent.click(renameButton) + + // Find input and enter new name + const input = screen.getByDisplayValue("Default Config") + fireEvent.input(input, { target: { value: "New Name" } }) + + // Cancel + const cancelButton = screen.getByTitle("Cancel") + fireEvent.click(cancelButton) + + // Verify rename was not called + expect(mockOnRenameConfig).not.toHaveBeenCalled() + + // Verify we're back to normal view + expect(screen.queryByDisplayValue("New Name")).not.toBeInTheDocument() + }) +}) diff --git a/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx b/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx index a82e85a..4beb30d 100644 --- a/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx +++ b/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx @@ -1,336 +1,340 @@ -import React from 'react' -import { render, screen, fireEvent } from '@testing-library/react' -import SettingsView from '../SettingsView' -import { ExtensionStateContextProvider } from '../../../context/ExtensionStateContext' -import { vscode } from '../../../utils/vscode' +import React from "react" +import { render, screen, fireEvent } from "@testing-library/react" +import SettingsView from "../SettingsView" +import { ExtensionStateContextProvider } from "../../../context/ExtensionStateContext" +import { vscode } from "../../../utils/vscode" // Mock vscode API -jest.mock('../../../utils/vscode', () => ({ - vscode: { - postMessage: jest.fn(), - }, +jest.mock("../../../utils/vscode", () => ({ + vscode: { + postMessage: jest.fn(), + }, })) // Mock ApiConfigManager component -jest.mock('../ApiConfigManager', () => ({ - __esModule: true, - default: ({ currentApiConfigName, listApiConfigMeta, onSelectConfig, onDeleteConfig, onRenameConfig, onUpsertConfig }: any) => ( -
- Current config: {currentApiConfigName} -
- ) +jest.mock("../ApiConfigManager", () => ({ + __esModule: true, + default: ({ + currentApiConfigName, + listApiConfigMeta, + onSelectConfig, + onDeleteConfig, + onRenameConfig, + onUpsertConfig, + }: any) => ( +
+ Current config: {currentApiConfigName} +
+ ), })) // Mock VSCode components -jest.mock('@vscode/webview-ui-toolkit/react', () => ({ - VSCodeButton: ({ children, onClick, appearance }: any) => ( - appearance === 'icon' ? - : - - ), - VSCodeCheckbox: ({ children, onChange, checked }: any) => ( - - ), - VSCodeTextField: ({ value, onInput, placeholder }: any) => ( - onInput({ target: { value: e.target.value } })} - placeholder={placeholder} - /> - ), - VSCodeTextArea: () =>