Merge branch 'main' into improve_diff_prompt

This commit is contained in:
Daniel
2025-01-18 17:01:21 -05:00
committed by GitHub
191 changed files with 32530 additions and 35310 deletions

View File

@@ -2,19 +2,19 @@
const getReleaseLine = async (changeset) => { const getReleaseLine = async (changeset) => {
const [firstLine] = changeset.summary const [firstLine] = changeset.summary
.split('\n') .split("\n")
.map(l => l.trim()) .map((l) => l.trim())
.filter(Boolean); .filter(Boolean)
return `- ${firstLine}`; return `- ${firstLine}`
}; }
const getDependencyReleaseLine = async () => { const getDependencyReleaseLine = async () => {
return ''; return ""
}; }
const changelogFunctions = { const changelogFunctions = {
getReleaseLine, getReleaseLine,
getDependencyReleaseLine, getDependencyReleaseLine,
}; }
module.exports = changelogFunctions; module.exports = changelogFunctions

View File

@@ -0,0 +1,5 @@
---
"roo-cline": patch
---
debug from vscode and changed output channel to Roo-Cline

View File

@@ -15,7 +15,6 @@
} }
], ],
"@typescript-eslint/semi": "off", "@typescript-eslint/semi": "off",
"curly": "warn",
"eqeqeq": "warn", "eqeqeq": "warn",
"no-throw-literal": "warn", "no-throw-literal": "warn",
"semi": "off", "semi": "off",

3
.git-blame-ignore-revs Normal file
View File

@@ -0,0 +1,3 @@
# Ran Prettier on all files - https://github.com/RooVetGit/Roo-Cline/pull/404
60a0a824b96a0b326af4d8871b6903f4ddcfe114
579bdd9dbf6d2d569e5e7adb5ff6292b1e42ea34

2
.gitconfig Normal file
View File

@@ -0,0 +1,2 @@
[blame]
ignoreRevsFile = .git-blame-ignore-revs

View File

@@ -1,28 +1,37 @@
<!-- **Note:** Consider creating PRs as a DRAFT. For early feedback and self-review. --> <!-- **Note:** Consider creating PRs as a DRAFT. For early feedback and self-review. -->
## Description ## Description
## Type of change ## Type of change
<!-- Please ignore options that are not relevant --> <!-- Please ignore options that are not relevant -->
- [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature - [ ] New feature
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update - [ ] This change requires a documentation update
## How Has This Been Tested? ## How Has This Been Tested?
<!-- Please describe the tests that you ran to verify your changes --> <!-- Please describe the tests that you ran to verify your changes -->
## Checklist: ## Checklist:
<!-- Go over all the following points, and put an `x` in all the boxes that apply --> <!-- Go over all the following points, and put an `x` in all the boxes that apply -->
- [ ] My code follows the patterns of this project - [ ] My code follows the patterns of this project
- [ ] I have performed a self-review of my own code - [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation - [ ] I have made corresponding changes to the documentation
## Additional context ## Additional context
<!-- Add any other context or screenshots about the pull request here --> <!-- Add any other context or screenshots about the pull request here -->
## Related Issues ## Related Issues
<!-- List any related issues here. Use the GitHub issue linking syntax: #issue-number --> <!-- List any related issues here. Use the GitHub issue linking syntax: #issue-number -->
## Reviewers ## Reviewers
<!-- @mention specific team members or individuals who should review this PR --> <!-- @mention specific team members or individuals who should review this PR -->

18
.vscode/launch.json vendored
View File

@@ -9,9 +9,21 @@
"name": "Run Extension", "name": "Run Extension",
"type": "extensionHost", "type": "extensionHost",
"request": "launch", "request": "launch",
"args": ["--extensionDevelopmentPath=${workspaceFolder}"], "runtimeExecutable": "${execPath}",
"args": [
"--extensionDevelopmentPath=${workspaceFolder}",
],
"sourceMaps": true,
"outFiles": ["${workspaceFolder}/dist/**/*.js"], "outFiles": ["${workspaceFolder}/dist/**/*.js"],
"preLaunchTask": "${defaultBuildTask}" "preLaunchTask": "compile",
} "env": {
"NODE_ENV": "development",
"VSCODE_DEBUG_MODE": "true"
},
"resolveSourceMapLocations": [
"${workspaceFolder}/**",
"!**/node_modules/**"
]
},
] ]
} }

17
.vscode/tasks.json vendored
View File

@@ -3,6 +3,21 @@
{ {
"version": "2.0.0", "version": "2.0.0",
"tasks": [ "tasks": [
{
"label": "compile",
"type": "npm",
"script": "compile",
"dependsOn": ["npm: build:webview"],
"group": {
"kind": "build",
"isDefault": true
},
"presentation": {
"reveal": "silent",
"panel": "shared"
},
"problemMatcher": ["$tsc", "$eslint-stylish"]
},
{ {
"label": "watch", "label": "watch",
"dependsOn": ["npm: build:webview", "npm: watch:tsc", "npm: watch:esbuild"], "dependsOn": ["npm: build:webview", "npm: watch:tsc", "npm: watch:esbuild"],
@@ -11,7 +26,7 @@
}, },
"group": { "group": {
"kind": "build", "kind": "build",
"isDefault": true "isDefault": false
} }
}, },
{ {

View File

@@ -11,6 +11,7 @@ Hot off the heels of **v3.0** introducing Code, Architect, and Ask chat modes, o
You can now tailor the **role definition** and **custom instructions** for every chat mode to perfectly fit your workflow. Want to adjust Architect mode to focus more on system scalability? Or tweak Ask mode for deeper research queries? Done. Plus, you can define these via **mode-specific `.clinerules-[mode]` files**. Youll find all of this in the new **Prompts** tab in the top menu. You can now tailor the **role definition** and **custom instructions** for every chat mode to perfectly fit your workflow. Want to adjust Architect mode to focus more on system scalability? Or tweak Ask mode for deeper research queries? Done. Plus, you can define these via **mode-specific `.clinerules-[mode]` files**. Youll find all of this in the new **Prompts** tab in the top menu.
The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Heres whats new: The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Heres whats new:
- Works with **any provider** and API configuration, not just OpenRouter. - Works with **any provider** and API configuration, not just OpenRouter.
- Fully customizable prompts to match your unique needs. - Fully customizable prompts to match your unique needs.
- Same simple workflow: just hit the ✨ **Enhance Prompt** button in the chat input to try it out. - Same simple workflow: just hit the ✨ **Enhance Prompt** button in the chat input to try it out.
@@ -33,6 +34,7 @@ You can now choose between different prompts for Roo Cline to better suit your w
Its super simple! Theres a dropdown in the bottom left of the chat input to switch modes. Right next to it, youll find a way to switch between the API configuration profiles associated with the current mode (configured on the settings screen). Its super simple! Theres a dropdown in the bottom left of the chat input to switch modes. Right next to it, youll find a way to switch between the API configuration profiles associated with the current mode (configured on the settings screen).
**Why Add This?** **Why Add This?**
- It keeps Cline from being overly eager to jump into solving problems when you just want to think or ask questions. - It keeps Cline from being overly eager to jump into solving problems when you just want to think or ask questions.
- Each mode remembers the API configuration you last used with it. For example, you can use more thoughtful models like OpenAI o1 for Architect and Ask, while sticking with Sonnet or DeepSeek for coding tasks. - Each mode remembers the API configuration you last used with it. For example, you can use more thoughtful models like OpenAI o1 for Architect and Ask, while sticking with Sonnet or DeepSeek for coding tasks.
- It builds on research suggesting better results when separating "thinking" from "coding," explained well in this very thoughtful [article](https://aider.chat/2024/09/26/architect.html) from aider. - It builds on research suggesting better results when separating "thinking" from "coding," explained well in this very thoughtful [article](https://aider.chat/2024/09/26/architect.html) from aider.
@@ -50,11 +52,13 @@ Here's an example of Roo-Cline autonomously creating a snake game with "Always a
https://github.com/user-attachments/assets/c2bb31dc-e9b2-4d73-885d-17f1471a4987 https://github.com/user-attachments/assets/c2bb31dc-e9b2-4d73-885d-17f1471a4987
## Contributing ## Contributing
To contribute to the project, start by exploring [open issues](https://github.com/RooVetGit/Roo-Cline/issues) or checking our [feature request board](https://github.com/RooVetGit/Roo-Cline/discussions/categories/feature-requests). We'd also love to have you join the [Roo Cline Reddit](https://www.reddit.com/r/roocline/) to share ideas and connect with other contributors. To contribute to the project, start by exploring [open issues](https://github.com/RooVetGit/Roo-Cline/issues) or checking our [feature request board](https://github.com/RooVetGit/Roo-Cline/discussions/categories/feature-requests). We'd also love to have you join the [Roo Cline Reddit](https://www.reddit.com/r/roocline/) to share ideas and connect with other contributors.
### Local Setup ### Local Setup
1. Install dependencies: 1. Install dependencies:
```bash ```bash
npm run install:all npm run install:all
``` ```
@@ -89,6 +93,7 @@ We use [changesets](https://github.com/changesets/changesets) for versioning and
4. Merge it 4. Merge it
Once your merge is successful: Once your merge is successful:
- The release workflow will automatically create a new "Changeset version bump" PR - The release workflow will automatically create a new "Changeset version bump" PR
- This PR will: - This PR will:
- Update the version based on your changeset - Update the version based on your changeset

View File

@@ -1,7 +1,7 @@
## For All Settings ## For All Settings
1. Add the setting to ExtensionMessage.ts: 1. Add the setting to ExtensionMessage.ts:
- Add the setting to the ExtensionState interface - Add the setting to the ExtensionState interface
- Make it required if it has a default value, optional if it can be undefined - Make it required if it has a default value, optional if it can be undefined
- Example: `preferredLanguage: string` - Example: `preferredLanguage: string`
@@ -14,10 +14,12 @@
## For Checkbox Settings ## For Checkbox Settings
1. Add the message type to WebviewMessage.ts: 1. Add the message type to WebviewMessage.ts:
- Add the setting name to the WebviewMessage type's type union - Add the setting name to the WebviewMessage type's type union
- Example: `| "multisearchDiffEnabled"` - Example: `| "multisearchDiffEnabled"`
2. Add the setting to ExtensionStateContext.tsx: 2. Add the setting to ExtensionStateContext.tsx:
- Add the setting to the ExtensionStateContextType interface - Add the setting to the ExtensionStateContextType interface
- Add the setter function to the interface - Add the setter function to the interface
- Add the setting to the initial state in useState - Add the setting to the initial state in useState
@@ -25,12 +27,13 @@
- Example: - Example:
```typescript ```typescript
interface ExtensionStateContextType { interface ExtensionStateContextType {
multisearchDiffEnabled: boolean; multisearchDiffEnabled: boolean
setMultisearchDiffEnabled: (value: boolean) => void; setMultisearchDiffEnabled: (value: boolean) => void
} }
``` ```
3. Add the setting to ClineProvider.ts: 3. Add the setting to ClineProvider.ts:
- Add the setting name to the GlobalStateKey type union - Add the setting name to the GlobalStateKey type union
- Add the setting to the Promise.all array in getState - Add the setting to the Promise.all array in getState
- Add the setting to the return value in getState with a default value - Add the setting to the return value in getState with a default value
@@ -46,6 +49,7 @@
``` ```
4. Add the checkbox UI to SettingsView.tsx: 4. Add the checkbox UI to SettingsView.tsx:
- Import the setting and its setter from ExtensionStateContext - Import the setting and its setter from ExtensionStateContext
- Add the VSCodeCheckbox component with the setting's state and onChange handler - Add the VSCodeCheckbox component with the setting's state and onChange handler
- Add appropriate labels and description text - Add appropriate labels and description text
@@ -69,10 +73,12 @@
## For Select/Dropdown Settings ## For Select/Dropdown Settings
1. Add the message type to WebviewMessage.ts: 1. Add the message type to WebviewMessage.ts:
- Add the setting name to the WebviewMessage type's type union - Add the setting name to the WebviewMessage type's type union
- Example: `| "preferredLanguage"` - Example: `| "preferredLanguage"`
2. Add the setting to ExtensionStateContext.tsx: 2. Add the setting to ExtensionStateContext.tsx:
- Add the setting to the ExtensionStateContextType interface - Add the setting to the ExtensionStateContextType interface
- Add the setter function to the interface - Add the setter function to the interface
- Add the setting to the initial state in useState with a default value - Add the setting to the initial state in useState with a default value
@@ -80,12 +86,13 @@
- Example: - Example:
```typescript ```typescript
interface ExtensionStateContextType { interface ExtensionStateContextType {
preferredLanguage: string; preferredLanguage: string
setPreferredLanguage: (value: string) => void; setPreferredLanguage: (value: string) => void
} }
``` ```
3. Add the setting to ClineProvider.ts: 3. Add the setting to ClineProvider.ts:
- Add the setting name to the GlobalStateKey type union - Add the setting name to the GlobalStateKey type union
- Add the setting to the Promise.all array in getState - Add the setting to the Promise.all array in getState
- Add the setting to the return value in getState with a default value - Add the setting to the return value in getState with a default value
@@ -101,6 +108,7 @@
``` ```
4. Add the select UI to SettingsView.tsx: 4. Add the select UI to SettingsView.tsx:
- Import the setting and its setter from ExtensionStateContext - Import the setting and its setter from ExtensionStateContext
- Add the select element with appropriate styling to match VSCode's theme - Add the select element with appropriate styling to match VSCode's theme
- Add options for the dropdown - Add options for the dropdown
@@ -132,6 +140,7 @@
``` ```
These steps ensure that: These steps ensure that:
- The setting's state is properly typed throughout the application - The setting's state is properly typed throughout the application
- The setting persists between sessions - The setting persists between sessions
- The setting's value is properly synchronized between the webview and extension - The setting's value is properly synchronized between the webview and extension

View File

@@ -1,41 +1,40 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */ /** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = { module.exports = {
preset: 'ts-jest', preset: "ts-jest",
testEnvironment: 'node', testEnvironment: "node",
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"],
transform: { transform: {
'^.+\\.tsx?$': ['ts-jest', { "^.+\\.tsx?$": [
"ts-jest",
{
tsconfig: { tsconfig: {
"module": "CommonJS", module: "CommonJS",
"moduleResolution": "node", moduleResolution: "node",
"esModuleInterop": true, esModuleInterop: true,
"allowJs": true allowJs: true,
}, },
diagnostics: false, diagnostics: false,
isolatedModules: true isolatedModules: true,
}]
}, },
testMatch: ['**/__tests__/**/*.test.ts'], ],
},
testMatch: ["**/__tests__/**/*.test.ts"],
moduleNameMapper: { moduleNameMapper: {
'^vscode$': '<rootDir>/src/__mocks__/vscode.js', "^vscode$": "<rootDir>/src/__mocks__/vscode.js",
'@modelcontextprotocol/sdk$': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js', "@modelcontextprotocol/sdk$": "<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js",
'@modelcontextprotocol/sdk/(.*)': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1', "@modelcontextprotocol/sdk/(.*)": "<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1",
'^delay$': '<rootDir>/src/__mocks__/delay.js', "^delay$": "<rootDir>/src/__mocks__/delay.js",
'^p-wait-for$': '<rootDir>/src/__mocks__/p-wait-for.js', "^p-wait-for$": "<rootDir>/src/__mocks__/p-wait-for.js",
'^globby$': '<rootDir>/src/__mocks__/globby.js', "^globby$": "<rootDir>/src/__mocks__/globby.js",
'^serialize-error$': '<rootDir>/src/__mocks__/serialize-error.js', "^serialize-error$": "<rootDir>/src/__mocks__/serialize-error.js",
'^strip-ansi$': '<rootDir>/src/__mocks__/strip-ansi.js', "^strip-ansi$": "<rootDir>/src/__mocks__/strip-ansi.js",
'^default-shell$': '<rootDir>/src/__mocks__/default-shell.js', "^default-shell$": "<rootDir>/src/__mocks__/default-shell.js",
'^os-name$': '<rootDir>/src/__mocks__/os-name.js' "^os-name$": "<rootDir>/src/__mocks__/os-name.js",
}, },
transformIgnorePatterns: [ transformIgnorePatterns: [
'node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)' "node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)",
], ],
modulePathIgnorePatterns: [ modulePathIgnorePatterns: [".vscode-test"],
'.vscode-test' reporters: [["jest-simple-dot-reporter", {}]],
], setupFiles: [],
reporters: [
["jest-simple-dot-reporter", {}]
],
setupFiles: []
} }

397
package-lock.json generated
View File

@@ -73,6 +73,7 @@
"jest-simple-dot-reporter": "^1.0.5", "jest-simple-dot-reporter": "^1.0.5",
"lint-staged": "^15.2.11", "lint-staged": "^15.2.11",
"npm-run-all": "^4.1.5", "npm-run-all": "^4.1.5",
"prettier": "^3.4.2",
"ts-jest": "^29.2.5", "ts-jest": "^29.2.5",
"typescript": "^5.4.5" "typescript": "^5.4.5"
}, },
@@ -2733,6 +2734,21 @@
"semver": "^7.5.3" "semver": "^7.5.3"
} }
}, },
"node_modules/@changesets/apply-release-plan/node_modules/prettier": {
"version": "2.8.8",
"resolved": "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz",
"integrity": "sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==",
"dev": true,
"bin": {
"prettier": "bin-prettier.js"
},
"engines": {
"node": ">=10.13.0"
},
"funding": {
"url": "https://github.com/prettier/prettier?sponsor=1"
}
},
"node_modules/@changesets/apply-release-plan/node_modules/resolve-from": { "node_modules/@changesets/apply-release-plan/node_modules/resolve-from": {
"version": "5.0.0", "version": "5.0.0",
"resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz",
@@ -2999,68 +3015,19 @@
"prettier": "^2.7.1" "prettier": "^2.7.1"
} }
}, },
"node_modules/@esbuild/aix-ppc64": { "node_modules/@changesets/write/node_modules/prettier": {
"version": "0.24.0", "version": "2.8.8",
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.24.0.tgz", "resolved": "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz",
"integrity": "sha512-WtKdFM7ls47zkKHFVzMz8opM7LkcsIp9amDUBIAWirg70RM71WRSjdILPsY5Uv1D42ZpUfaPILDlfactHgsRkw==", "integrity": "sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==",
"cpu": [
"ppc64"
],
"dev": true, "dev": true,
"optional": true, "bin": {
"os": [ "prettier": "bin-prettier.js"
"aix"
],
"engines": {
"node": ">=18"
}
}, },
"node_modules/@esbuild/android-arm": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.24.0.tgz",
"integrity": "sha512-arAtTPo76fJ/ICkXWetLCc9EwEHKaeya4vMrReVlEIUCAUncH7M4bhMQ+M9Vf+FFOZJdTNMXNBrWwW+OXWpSew==",
"cpu": [
"arm"
],
"dev": true,
"optional": true,
"os": [
"android"
],
"engines": { "engines": {
"node": ">=18" "node": ">=10.13.0"
}
}, },
"node_modules/@esbuild/android-arm64": { "funding": {
"version": "0.24.0", "url": "https://github.com/prettier/prettier?sponsor=1"
"resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.24.0.tgz",
"integrity": "sha512-Vsm497xFM7tTIPYK9bNTYJyF/lsP590Qc1WxJdlB6ljCbdZKU9SY8i7+Iin4kyhV/KV5J2rOKsBQbB77Ab7L/w==",
"cpu": [
"arm64"
],
"dev": true,
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.24.0.tgz",
"integrity": "sha512-t8GrvnFkiIY7pa7mMgJd7p8p8qqYIz1NYiAoKc75Zyv73L3DZW++oYMSHPRarcotTKuSs6m3hTOa5CKHaS02TQ==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
} }
}, },
"node_modules/@esbuild/darwin-arm64": { "node_modules/@esbuild/darwin-arm64": {
@@ -3079,310 +3046,6 @@
"node": ">=18" "node": ">=18"
} }
}, },
"node_modules/@esbuild/darwin-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.24.0.tgz",
"integrity": "sha512-rgtz6flkVkh58od4PwTRqxbKH9cOjaXCMZgWD905JOzjFKW+7EiUObfd/Kav+A6Gyud6WZk9w+xu6QLytdi2OA==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/freebsd-arm64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.24.0.tgz",
"integrity": "sha512-6Mtdq5nHggwfDNLAHkPlyLBpE5L6hwsuXZX8XNmHno9JuL2+bg2BX5tRkwjyfn6sKbxZTq68suOjgWqCicvPXA==",
"cpu": [
"arm64"
],
"dev": true,
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/freebsd-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.24.0.tgz",
"integrity": "sha512-D3H+xh3/zphoX8ck4S2RxKR6gHlHDXXzOf6f/9dbFt/NRBDIE33+cVa49Kil4WUjxMGW0ZIYBYtaGCa2+OsQwQ==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-arm": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.24.0.tgz",
"integrity": "sha512-gJKIi2IjRo5G6Glxb8d3DzYXlxdEj2NlkixPsqePSZMhLudqPhtZ4BUrpIuTjJYXxvF9njql+vRjB2oaC9XpBw==",
"cpu": [
"arm"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-arm64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.24.0.tgz",
"integrity": "sha512-TDijPXTOeE3eaMkRYpcy3LarIg13dS9wWHRdwYRnzlwlA370rNdZqbcp0WTyyV/k2zSxfko52+C7jU5F9Tfj1g==",
"cpu": [
"arm64"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-ia32": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.24.0.tgz",
"integrity": "sha512-K40ip1LAcA0byL05TbCQ4yJ4swvnbzHscRmUilrmP9Am7//0UjPreh4lpYzvThT2Quw66MhjG//20mrufm40mA==",
"cpu": [
"ia32"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-loong64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.24.0.tgz",
"integrity": "sha512-0mswrYP/9ai+CU0BzBfPMZ8RVm3RGAN/lmOMgW4aFUSOQBjA31UP8Mr6DDhWSuMwj7jaWOT0p0WoZ6jeHhrD7g==",
"cpu": [
"loong64"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-mips64el": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.24.0.tgz",
"integrity": "sha512-hIKvXm0/3w/5+RDtCJeXqMZGkI2s4oMUGj3/jM0QzhgIASWrGO5/RlzAzm5nNh/awHE0A19h/CvHQe6FaBNrRA==",
"cpu": [
"mips64el"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-ppc64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.24.0.tgz",
"integrity": "sha512-HcZh5BNq0aC52UoocJxaKORfFODWXZxtBaaZNuN3PUX3MoDsChsZqopzi5UupRhPHSEHotoiptqikjN/B77mYQ==",
"cpu": [
"ppc64"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-riscv64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.24.0.tgz",
"integrity": "sha512-bEh7dMn/h3QxeR2KTy1DUszQjUrIHPZKyO6aN1X4BCnhfYhuQqedHaa5MxSQA/06j3GpiIlFGSsy1c7Gf9padw==",
"cpu": [
"riscv64"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-s390x": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.24.0.tgz",
"integrity": "sha512-ZcQ6+qRkw1UcZGPyrCiHHkmBaj9SiCD8Oqd556HldP+QlpUIe2Wgn3ehQGVoPOvZvtHm8HPx+bH20c9pvbkX3g==",
"cpu": [
"s390x"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.24.0.tgz",
"integrity": "sha512-vbutsFqQ+foy3wSSbmjBXXIJ6PL3scghJoM8zCL142cGaZKAdCZHyf+Bpu/MmX9zT9Q0zFBVKb36Ma5Fzfa8xA==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/netbsd-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.24.0.tgz",
"integrity": "sha512-hjQ0R/ulkO8fCYFsG0FZoH+pWgTTDreqpqY7UnQntnaKv95uP5iW3+dChxnx7C3trQQU40S+OgWhUVwCjVFLvg==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"netbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/openbsd-arm64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.24.0.tgz",
"integrity": "sha512-MD9uzzkPQbYehwcN583yx3Tu5M8EIoTD+tUgKF982WYL9Pf5rKy9ltgD0eUgs8pvKnmizxjXZyLt0z6DC3rRXg==",
"cpu": [
"arm64"
],
"dev": true,
"optional": true,
"os": [
"openbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/openbsd-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.24.0.tgz",
"integrity": "sha512-4ir0aY1NGUhIC1hdoCzr1+5b43mw99uNwVzhIq1OY3QcEwPDO3B7WNXBzaKY5Nsf1+N11i1eOfFcq+D/gOS15Q==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"openbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/sunos-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.24.0.tgz",
"integrity": "sha512-jVzdzsbM5xrotH+W5f1s+JtUy1UWgjU0Cf4wMvffTB8m6wP5/kx0KiaLHlbJO+dMgtxKV8RQ/JvtlFcdZ1zCPA==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"sunos"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-arm64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.24.0.tgz",
"integrity": "sha512-iKc8GAslzRpBytO2/aN3d2yb2z8XTVfNV0PjGlCxKo5SgWmNXx82I/Q3aG1tFfS+A2igVCY97TJ8tnYwpUWLCA==",
"cpu": [
"arm64"
],
"dev": true,
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-ia32": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.24.0.tgz",
"integrity": "sha512-vQW36KZolfIudCcTnaTpmLQ24Ha1RjygBo39/aLkM2kmjkWmZGEJ5Gn9l5/7tzXA42QGIoWbICfg6KLLkIw6yw==",
"cpu": [
"ia32"
],
"dev": true,
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-x64": {
"version": "0.24.0",
"resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.24.0.tgz",
"integrity": "sha512-7IAFPrjSQIJrGsK6flwg7NFmwBoSTyF3rl7If0hNUFQU4ilTsEPL6GuMuU9BfIWVVGuRnuIidkSMC+c0Otu8IA==",
"cpu": [
"x64"
],
"dev": true,
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@eslint-community/eslint-utils": { "node_modules/@eslint-community/eslint-utils": {
"version": "4.4.1", "version": "4.4.1",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.1.tgz", "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.1.tgz",
@@ -13217,15 +12880,15 @@
} }
}, },
"node_modules/prettier": { "node_modules/prettier": {
"version": "2.8.8", "version": "3.4.2",
"resolved": "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz", "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.4.2.tgz",
"integrity": "sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==", "integrity": "sha512-e9MewbtFo+Fevyuxn/4rrcDAaq0IYxPGLvObpQjiZBMAzB9IGmzlnG9RZy3FFas+eBMu2vA0CszMeduow5dIuQ==",
"dev": true, "dev": true,
"bin": { "bin": {
"prettier": "bin-prettier.js" "prettier": "bin/prettier.cjs"
}, },
"engines": { "engines": {
"node": ">=10.13.0" "node": ">=14"
}, },
"funding": { "funding": {
"url": "https://github.com/prettier/prettier?sponsor=1" "url": "https://github.com/prettier/prettier?sponsor=1"

View File

@@ -219,6 +219,7 @@
"jest-simple-dot-reporter": "^1.0.5", "jest-simple-dot-reporter": "^1.0.5",
"lint-staged": "^15.2.11", "lint-staged": "^15.2.11",
"npm-run-all": "^4.1.5", "npm-run-all": "^4.1.5",
"prettier": "^3.4.2",
"ts-jest": "^29.2.5", "ts-jest": "^29.2.5",
"typescript": "^5.4.5" "typescript": "^5.4.5"
}, },
@@ -268,8 +269,12 @@
"zod": "^3.23.8" "zod": "^3.23.8"
}, },
"lint-staged": { "lint-staged": {
"*.{js,jsx,ts,tsx,json,css,md}": [
"prettier --write"
],
"src/**/*.{ts,tsx}": [ "src/**/*.{ts,tsx}": [
"npx eslint -c .eslintrc.json" "prettier --write",
"npx eslint -c .eslintrc.json --fix"
] ]
} }
} }

View File

@@ -13,5 +13,5 @@ class Client {
} }
module.exports = { module.exports = {
Client Client,
} }

View File

@@ -3,14 +3,14 @@ class StdioClientTransport {
this.start = jest.fn().mockResolvedValue(undefined) this.start = jest.fn().mockResolvedValue(undefined)
this.close = jest.fn().mockResolvedValue(undefined) this.close = jest.fn().mockResolvedValue(undefined)
this.stderr = { this.stderr = {
on: jest.fn() on: jest.fn(),
} }
} }
} }
class StdioServerParameters { class StdioServerParameters {
constructor() { constructor() {
this.command = '' this.command = ""
this.args = [] this.args = []
this.env = {} this.env = {}
} }
@@ -18,5 +18,5 @@ class StdioServerParameters {
module.exports = { module.exports = {
StdioClientTransport, StdioClientTransport,
StdioServerParameters StdioServerParameters,
} }

View File

@@ -1,5 +1,5 @@
const { Client } = require('./client/index.js') const { Client } = require("./client/index.js")
const { StdioClientTransport, StdioServerParameters } = require('./client/stdio.js') const { StdioClientTransport, StdioServerParameters } = require("./client/stdio.js")
const { const {
CallToolResultSchema, CallToolResultSchema,
ListToolsResultSchema, ListToolsResultSchema,
@@ -7,8 +7,8 @@ const {
ListResourceTemplatesResultSchema, ListResourceTemplatesResultSchema,
ReadResourceResultSchema, ReadResourceResultSchema,
ErrorCode, ErrorCode,
McpError McpError,
} = require('./types.js') } = require("./types.js")
module.exports = { module.exports = {
Client, Client,
@@ -20,5 +20,5 @@ module.exports = {
ListResourceTemplatesResultSchema, ListResourceTemplatesResultSchema,
ReadResourceResultSchema, ReadResourceResultSchema,
ErrorCode, ErrorCode,
McpError McpError,
} }

View File

@@ -1,36 +1,36 @@
const CallToolResultSchema = { const CallToolResultSchema = {
parse: jest.fn().mockReturnValue({}) parse: jest.fn().mockReturnValue({}),
} }
const ListToolsResultSchema = { const ListToolsResultSchema = {
parse: jest.fn().mockReturnValue({ parse: jest.fn().mockReturnValue({
tools: [] tools: [],
}) }),
} }
const ListResourcesResultSchema = { const ListResourcesResultSchema = {
parse: jest.fn().mockReturnValue({ parse: jest.fn().mockReturnValue({
resources: [] resources: [],
}) }),
} }
const ListResourceTemplatesResultSchema = { const ListResourceTemplatesResultSchema = {
parse: jest.fn().mockReturnValue({ parse: jest.fn().mockReturnValue({
resourceTemplates: [] resourceTemplates: [],
}) }),
} }
const ReadResourceResultSchema = { const ReadResourceResultSchema = {
parse: jest.fn().mockReturnValue({ parse: jest.fn().mockReturnValue({
contents: [] contents: [],
}) }),
} }
const ErrorCode = { const ErrorCode = {
InvalidRequest: 'InvalidRequest', InvalidRequest: "InvalidRequest",
MethodNotFound: 'MethodNotFound', MethodNotFound: "MethodNotFound",
InvalidParams: 'InvalidParams', InvalidParams: "InvalidParams",
InternalError: 'InternalError' InternalError: "InternalError",
} }
class McpError extends Error { class McpError extends Error {
@@ -47,5 +47,5 @@ module.exports = {
ListResourceTemplatesResultSchema, ListResourceTemplatesResultSchema,
ReadResourceResultSchema, ReadResourceResultSchema,
ErrorCode, ErrorCode,
McpError McpError,
} }

View File

@@ -12,6 +12,6 @@ export class McpHub {
} }
async callTool(serverName: string, toolName: string, toolArguments?: Record<string, unknown>): Promise<any> { async callTool(serverName: string, toolName: string, toolArguments?: Record<string, unknown>): Promise<any> {
return Promise.resolve({ result: 'success' }) return Promise.resolve({ result: "success" })
} }
} }

View File

@@ -1,12 +1,12 @@
// Mock default shell based on platform // Mock default shell based on platform
const os = require('os'); const os = require("os")
let defaultShell; let defaultShell
if (os.platform() === 'win32') { if (os.platform() === "win32") {
defaultShell = 'cmd.exe'; defaultShell = "cmd.exe"
} else { } else {
defaultShell = '/bin/bash'; defaultShell = "/bin/bash"
} }
module.exports = defaultShell; module.exports = defaultShell
module.exports.default = defaultShell; module.exports.default = defaultShell

View File

@@ -1,6 +1,6 @@
function delay(ms) { function delay(ms) {
return new Promise(resolve => setTimeout(resolve, ms)); return new Promise((resolve) => setTimeout(resolve, ms))
} }
module.exports = delay; module.exports = delay
module.exports.default = delay; module.exports.default = delay

View File

@@ -1,10 +1,10 @@
function globby(patterns, options) { function globby(patterns, options) {
return Promise.resolve([]); return Promise.resolve([])
} }
globby.sync = function(patterns, options) { globby.sync = function (patterns, options) {
return []; return []
}; }
module.exports = globby; module.exports = globby
module.exports.default = globby; module.exports.default = globby

View File

@@ -1,6 +1,6 @@
function osName() { function osName() {
return 'macOS'; return "macOS"
} }
module.exports = osName; module.exports = osName
module.exports.default = osName; module.exports.default = osName

View File

@@ -2,19 +2,19 @@ function pWaitFor(condition, options = {}) {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const interval = setInterval(() => { const interval = setInterval(() => {
if (condition()) { if (condition()) {
clearInterval(interval); clearInterval(interval)
resolve(); resolve()
} }
}, options.interval || 20); }, options.interval || 20)
if (options.timeout) { if (options.timeout) {
setTimeout(() => { setTimeout(() => {
clearInterval(interval); clearInterval(interval)
reject(new Error('Timed out')); reject(new Error("Timed out"))
}, options.timeout); }, options.timeout)
} }
}); })
} }
module.exports = pWaitFor; module.exports = pWaitFor
module.exports.default = pWaitFor; module.exports.default = pWaitFor

View File

@@ -3,23 +3,23 @@ function serializeError(error) {
return { return {
name: error.name, name: error.name,
message: error.message, message: error.message,
stack: error.stack stack: error.stack,
};
} }
return error; }
return error
} }
function deserializeError(errorData) { function deserializeError(errorData) {
if (errorData && typeof errorData === 'object') { if (errorData && typeof errorData === "object") {
const error = new Error(errorData.message); const error = new Error(errorData.message)
error.name = errorData.name; error.name = errorData.name
error.stack = errorData.stack; error.stack = errorData.stack
return error; return error
} }
return errorData; return errorData
} }
module.exports = { module.exports = {
serializeError, serializeError,
deserializeError deserializeError,
}; }

View File

@@ -1,7 +1,7 @@
function stripAnsi(string) { function stripAnsi(string) {
// Simple mock that just returns the input string // Simple mock that just returns the input string
return string; return string
} }
module.exports = stripAnsi; module.exports = stripAnsi
module.exports.default = stripAnsi; module.exports.default = stripAnsi

View File

@@ -3,11 +3,11 @@ const vscode = {
showInformationMessage: jest.fn(), showInformationMessage: jest.fn(),
showErrorMessage: jest.fn(), showErrorMessage: jest.fn(),
createTextEditorDecorationType: jest.fn().mockReturnValue({ createTextEditorDecorationType: jest.fn().mockReturnValue({
dispose: jest.fn() dispose: jest.fn(),
}) }),
}, },
workspace: { workspace: {
onDidSaveTextDocument: jest.fn() onDidSaveTextDocument: jest.fn(),
}, },
Disposable: class { Disposable: class {
dispose() {} dispose() {}
@@ -15,43 +15,43 @@ const vscode = {
Uri: { Uri: {
file: (path) => ({ file: (path) => ({
fsPath: path, fsPath: path,
scheme: 'file', scheme: "file",
authority: '', authority: "",
path: path, path: path,
query: '', query: "",
fragment: '', fragment: "",
with: jest.fn(), with: jest.fn(),
toJSON: jest.fn() toJSON: jest.fn(),
}) }),
}, },
EventEmitter: class { EventEmitter: class {
constructor() { constructor() {
this.event = jest.fn(); this.event = jest.fn()
this.fire = jest.fn(); this.fire = jest.fn()
} }
}, },
ConfigurationTarget: { ConfigurationTarget: {
Global: 1, Global: 1,
Workspace: 2, Workspace: 2,
WorkspaceFolder: 3 WorkspaceFolder: 3,
}, },
Position: class { Position: class {
constructor(line, character) { constructor(line, character) {
this.line = line; this.line = line
this.character = character; this.character = character
} }
}, },
Range: class { Range: class {
constructor(startLine, startCharacter, endLine, endCharacter) { constructor(startLine, startCharacter, endLine, endCharacter) {
this.start = new vscode.Position(startLine, startCharacter); this.start = new vscode.Position(startLine, startCharacter)
this.end = new vscode.Position(endLine, endCharacter); this.end = new vscode.Position(endLine, endCharacter)
} }
}, },
ThemeColor: class { ThemeColor: class {
constructor(id) { constructor(id) {
this.id = id; this.id = id
} }
} },
}; }
module.exports = vscode; module.exports = vscode

View File

@@ -1,12 +1,12 @@
import { AnthropicHandler } from '../anthropic'; import { AnthropicHandler } from "../anthropic"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import { ApiStream } from '../../transform/stream'; import { ApiStream } from "../../transform/stream"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock Anthropic client // Mock Anthropic client
const mockBetaCreate = jest.fn(); const mockBetaCreate = jest.fn()
const mockCreate = jest.fn(); const mockCreate = jest.fn()
jest.mock('@anthropic-ai/sdk', () => { jest.mock("@anthropic-ai/sdk", () => {
return { return {
Anthropic: jest.fn().mockImplementation(() => ({ Anthropic: jest.fn().mockImplementation(() => ({
beta: { beta: {
@@ -15,225 +15,224 @@ jest.mock('@anthropic-ai/sdk', () => {
create: mockBetaCreate.mockImplementation(async () => ({ create: mockBetaCreate.mockImplementation(async () => ({
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
type: 'message_start', type: "message_start",
message: { message: {
usage: { usage: {
input_tokens: 100, input_tokens: 100,
output_tokens: 50, output_tokens: 50,
cache_creation_input_tokens: 20, cache_creation_input_tokens: 20,
cache_read_input_tokens: 10 cache_read_input_tokens: 10,
},
},
} }
}
};
yield { yield {
type: 'content_block_start', type: "content_block_start",
index: 0, index: 0,
content_block: { content_block: {
type: 'text', type: "text",
text: 'Hello' text: "Hello",
},
} }
};
yield { yield {
type: 'content_block_delta', type: "content_block_delta",
delta: { delta: {
type: 'text_delta', type: "text_delta",
text: ' world' text: " world",
} },
};
}
}))
}
} }
}, },
})),
},
},
},
messages: { messages: {
create: mockCreate.mockImplementation(async (options) => { create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
content: [ content: [{ type: "text", text: "Test response" }],
{ type: 'text', text: 'Test response' } role: "assistant",
],
role: 'assistant',
model: options.model, model: options.model,
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 5 output_tokens: 5,
} },
} }
} }
return { return {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
type: 'message_start', type: "message_start",
message: { message: {
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 5 output_tokens: 5,
} },
} },
} }
yield { yield {
type: 'content_block_start', type: "content_block_start",
content_block: { content_block: {
type: 'text', type: "text",
text: 'Test response' text: "Test response",
},
} }
},
} }
}),
},
})),
} }
} })
})
}
}))
};
});
describe('AnthropicHandler', () => { describe("AnthropicHandler", () => {
let handler: AnthropicHandler; let handler: AnthropicHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
apiKey: 'test-api-key', apiKey: "test-api-key",
apiModelId: 'claude-3-5-sonnet-20241022' apiModelId: "claude-3-5-sonnet-20241022",
}; }
handler = new AnthropicHandler(mockOptions); handler = new AnthropicHandler(mockOptions)
mockBetaCreate.mockClear(); mockBetaCreate.mockClear()
mockCreate.mockClear(); mockCreate.mockClear()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(AnthropicHandler); expect(handler).toBeInstanceOf(AnthropicHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId); expect(handler.getModel().id).toBe(mockOptions.apiModelId)
}); })
it('should initialize with undefined API key', () => { it("should initialize with undefined API key", () => {
// The SDK will handle API key validation, so we just verify it initializes // The SDK will handle API key validation, so we just verify it initializes
const handlerWithoutKey = new AnthropicHandler({ const handlerWithoutKey = new AnthropicHandler({
...mockOptions, ...mockOptions,
apiKey: undefined apiKey: undefined,
}); })
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler); expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler)
}); })
it('should use custom base URL if provided', () => { it("should use custom base URL if provided", () => {
const customBaseUrl = 'https://custom.anthropic.com'; const customBaseUrl = "https://custom.anthropic.com"
const handlerWithCustomUrl = new AnthropicHandler({ const handlerWithCustomUrl = new AnthropicHandler({
...mockOptions, ...mockOptions,
anthropicBaseUrl: customBaseUrl anthropicBaseUrl: customBaseUrl,
}); })
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler); expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [{ content: [
type: 'text' as const, {
text: 'Hello!' type: "text" as const,
}] text: "Hello!",
} },
]; ],
},
]
it('should handle prompt caching for supported models', async () => { it("should handle prompt caching for supported models", async () => {
const stream = handler.createMessage(systemPrompt, [ const stream = handler.createMessage(systemPrompt, [
{ {
role: 'user', role: "user",
content: [{ type: 'text' as const, text: 'First message' }] content: [{ type: "text" as const, text: "First message" }],
}, },
{ {
role: 'assistant', role: "assistant",
content: [{ type: 'text' as const, text: 'Response' }] content: [{ type: "text" as const, text: "Response" }],
}, },
{ {
role: 'user', role: "user",
content: [{ type: 'text' as const, text: 'Second message' }] content: [{ type: "text" as const, text: "Second message" }],
} },
]); ])
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
// Verify usage information // Verify usage information
const usageChunk = chunks.find(chunk => chunk.type === 'usage'); const usageChunk = chunks.find((chunk) => chunk.type === "usage")
expect(usageChunk).toBeDefined(); expect(usageChunk).toBeDefined()
expect(usageChunk?.inputTokens).toBe(100); expect(usageChunk?.inputTokens).toBe(100)
expect(usageChunk?.outputTokens).toBe(50); expect(usageChunk?.outputTokens).toBe(50)
expect(usageChunk?.cacheWriteTokens).toBe(20); expect(usageChunk?.cacheWriteTokens).toBe(20)
expect(usageChunk?.cacheReadTokens).toBe(10); expect(usageChunk?.cacheReadTokens).toBe(10)
// Verify text content // Verify text content
const textChunks = chunks.filter(chunk => chunk.type === 'text'); const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(2); expect(textChunks).toHaveLength(2)
expect(textChunks[0].text).toBe('Hello'); expect(textChunks[0].text).toBe("Hello")
expect(textChunks[1].text).toBe(' world'); expect(textChunks[1].text).toBe(" world")
// Verify beta API was used // Verify beta API was used
expect(mockBetaCreate).toHaveBeenCalled(); expect(mockBetaCreate).toHaveBeenCalled()
expect(mockCreate).not.toHaveBeenCalled(); expect(mockCreate).not.toHaveBeenCalled()
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId, model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192, max_tokens: 8192,
temperature: 0, temperature: 0,
stream: false stream: false,
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error")
.rejects.toThrow('Anthropic completion error: API Error'); })
});
it('should handle non-text content', async () => { it("should handle non-text content", async () => {
mockCreate.mockImplementationOnce(async () => ({ mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'image' }] content: [{ type: "image" }],
})); }))
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
it('should handle empty response', async () => { it("should handle empty response", async () => {
mockCreate.mockImplementationOnce(async () => ({ mockCreate.mockImplementationOnce(async () => ({
content: [{ type: 'text', text: '' }] content: [{ type: "text", text: "" }],
})); }))
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return default model if no model ID is provided', () => { it("should return default model if no model ID is provided", () => {
const handlerWithoutModel = new AnthropicHandler({ const handlerWithoutModel = new AnthropicHandler({
...mockOptions, ...mockOptions,
apiModelId: undefined apiModelId: undefined,
}); })
const model = handlerWithoutModel.getModel(); const model = handlerWithoutModel.getModel()
expect(model.id).toBeDefined(); expect(model.id).toBeDefined()
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
}); })
it('should return specified model if valid model ID is provided', () => { it("should return specified model if valid model ID is provided", () => {
const model = handler.getModel(); const model = handler.getModel()
expect(model.id).toBe(mockOptions.apiModelId); expect(model.id).toBe(mockOptions.apiModelId)
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
expect(model.info.maxTokens).toBe(8192); expect(model.info.maxTokens).toBe(8192)
expect(model.info.contextWindow).toBe(200_000); expect(model.info.contextWindow).toBe(200_000)
expect(model.info.supportsImages).toBe(true); expect(model.info.supportsImages).toBe(true)
expect(model.info.supportsPromptCache).toBe(true); expect(model.info.supportsPromptCache).toBe(true)
}); })
}); })
}); })

View File

@@ -1,62 +1,64 @@
import { AwsBedrockHandler } from '../bedrock'; import { AwsBedrockHandler } from "../bedrock"
import { MessageContent } from '../../../shared/api'; import { MessageContent } from "../../../shared/api"
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'; import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
describe('AwsBedrockHandler', () => { describe("AwsBedrockHandler", () => {
let handler: AwsBedrockHandler; let handler: AwsBedrockHandler
beforeEach(() => { beforeEach(() => {
handler = new AwsBedrockHandler({ handler = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: 'test-access-key', awsAccessKey: "test-access-key",
awsSecretKey: 'test-secret-key', awsSecretKey: "test-secret-key",
awsRegion: 'us-east-1' awsRegion: "us-east-1",
}); })
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided config', () => { it("should initialize with provided config", () => {
expect(handler['options'].awsAccessKey).toBe('test-access-key'); expect(handler["options"].awsAccessKey).toBe("test-access-key")
expect(handler['options'].awsSecretKey).toBe('test-secret-key'); expect(handler["options"].awsSecretKey).toBe("test-secret-key")
expect(handler['options'].awsRegion).toBe('us-east-1'); expect(handler["options"].awsRegion).toBe("us-east-1")
expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0'); expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
}); })
it('should initialize with missing AWS credentials', () => { it("should initialize with missing AWS credentials", () => {
const handlerWithoutCreds = new AwsBedrockHandler({ const handlerWithoutCreds = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsRegion: 'us-east-1' awsRegion: "us-east-1",
}); })
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler); expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [ const mockMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello' content: "Hello",
}, },
{ {
role: 'assistant', role: "assistant",
content: 'Hi there!' content: "Hi there!",
} },
]; ]
const systemPrompt = 'You are a helpful assistant'; const systemPrompt = "You are a helpful assistant"
it('should handle text messages correctly', async () => { it("should handle text messages correctly", async () => {
const mockResponse = { const mockResponse = {
messages: [{ messages: [
role: 'assistant', {
content: [{ type: 'text', text: 'Hello! How can I help you?' }] role: "assistant",
}], content: [{ type: "text", text: "Hello! How can I help you?" }],
},
],
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 5 output_tokens: 5,
},
} }
};
// Mock AWS SDK invoke // Mock AWS SDK invoke
const mockStream = { const mockStream = {
@@ -65,182 +67,193 @@ describe('AwsBedrockHandler', () => {
metadata: { metadata: {
usage: { usage: {
inputTokens: 10, inputTokens: 10,
outputTokens: 5 outputTokens: 5,
},
},
} }
},
} }
};
}
};
const mockInvoke = jest.fn().mockResolvedValue({ const mockInvoke = jest.fn().mockResolvedValue({
stream: mockStream stream: mockStream,
}); })
handler['client'] = { handler["client"] = {
send: mockInvoke send: mockInvoke,
} as unknown as BedrockRuntimeClient; } as unknown as BedrockRuntimeClient
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []; const chunks = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'usage', type: "usage",
inputTokens: 10, inputTokens: 10,
outputTokens: 5 outputTokens: 5,
});
expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({
input: expect.objectContaining({
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0'
}) })
}));
});
it('should handle API errors', async () => { expect(mockInvoke).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
}),
}),
)
})
it("should handle API errors", async () => {
// Mock AWS SDK invoke with error // Mock AWS SDK invoke with error
const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error')); const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
handler['client'] = { handler["client"] = {
send: mockInvoke send: mockInvoke,
} as unknown as BedrockRuntimeClient; } as unknown as BedrockRuntimeClient
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should throw before yielding any chunks // Should throw before yielding any chunks
} }
}).rejects.toThrow('AWS Bedrock error'); }).rejects.toThrow("AWS Bedrock error")
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const mockResponse = { const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({ output: new TextEncoder().encode(
content: 'Test response' JSON.stringify({
})) content: "Test response",
}; }),
),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse); const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler['client'] = { handler["client"] = {
send: mockSend send: mockSend,
} as unknown as BedrockRuntimeClient; } as unknown as BedrockRuntimeClient
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({ expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({ input: expect.objectContaining({
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
messages: expect.arrayContaining([ messages: expect.arrayContaining([
expect.objectContaining({ expect.objectContaining({
role: 'user', role: "user",
content: [{ text: 'Test prompt' }] content: [{ text: "Test prompt" }],
}) }),
]), ]),
inferenceConfig: expect.objectContaining({ inferenceConfig: expect.objectContaining({
maxTokens: 5000, maxTokens: 5000,
temperature: 0.3, temperature: 0.3,
topP: 0.1 topP: 0.1,
}),
}),
}),
)
}) })
it("should handle API errors", async () => {
const mockError = new Error("AWS Bedrock error")
const mockSend = jest.fn().mockRejectedValue(mockError)
handler["client"] = {
send: mockSend,
} as unknown as BedrockRuntimeClient
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Bedrock completion error: AWS Bedrock error",
)
}) })
}));
});
it('should handle API errors', async () => { it("should handle invalid response format", async () => {
const mockError = new Error('AWS Bedrock error');
const mockSend = jest.fn().mockRejectedValue(mockError);
handler['client'] = {
send: mockSend
} as unknown as BedrockRuntimeClient;
await expect(handler.completePrompt('Test prompt'))
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
});
it('should handle invalid response format', async () => {
const mockResponse = { const mockResponse = {
output: new TextEncoder().encode('invalid json') output: new TextEncoder().encode("invalid json"),
}; }
const mockSend = jest.fn().mockResolvedValue(mockResponse); const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler['client'] = { handler["client"] = {
send: mockSend send: mockSend,
} as unknown as BedrockRuntimeClient; } as unknown as BedrockRuntimeClient
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
it('should handle empty response', async () => { it("should handle empty response", async () => {
const mockResponse = { const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({})) output: new TextEncoder().encode(JSON.stringify({})),
}; }
const mockSend = jest.fn().mockResolvedValue(mockResponse); const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler['client'] = { handler["client"] = {
send: mockSend send: mockSend,
} as unknown as BedrockRuntimeClient; } as unknown as BedrockRuntimeClient
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
it('should handle cross-region inference', async () => { it("should handle cross-region inference", async () => {
handler = new AwsBedrockHandler({ handler = new AwsBedrockHandler({
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: 'test-access-key', awsAccessKey: "test-access-key",
awsSecretKey: 'test-secret-key', awsSecretKey: "test-secret-key",
awsRegion: 'us-east-1', awsRegion: "us-east-1",
awsUseCrossRegionInference: true awsUseCrossRegionInference: true,
}); })
const mockResponse = { const mockResponse = {
output: new TextEncoder().encode(JSON.stringify({ output: new TextEncoder().encode(
content: 'Test response' JSON.stringify({
})) content: "Test response",
}; }),
),
}
const mockSend = jest.fn().mockResolvedValue(mockResponse); const mockSend = jest.fn().mockResolvedValue(mockResponse)
handler['client'] = { handler["client"] = {
send: mockSend send: mockSend,
} as unknown as BedrockRuntimeClient; } as unknown as BedrockRuntimeClient
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({ expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
input: expect.objectContaining({ input: expect.objectContaining({
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
}),
}),
)
})
}) })
}));
});
});
describe('getModel', () => { describe("getModel", () => {
it('should return correct model info in test environment', () => { it("should return correct model info in test environment", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0'); expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
}); })
it('should return test model info for invalid model in test environment', () => { it("should return test model info for invalid model in test environment", () => {
const invalidHandler = new AwsBedrockHandler({ const invalidHandler = new AwsBedrockHandler({
apiModelId: 'invalid-model', apiModelId: "invalid-model",
awsAccessKey: 'test-access-key', awsAccessKey: "test-access-key",
awsSecretKey: 'test-secret-key', awsSecretKey: "test-secret-key",
awsRegion: 'us-east-1' awsRegion: "us-east-1",
}); })
const modelInfo = invalidHandler.getModel(); const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
expect(modelInfo.info.maxTokens).toBe(5000); expect(modelInfo.info.maxTokens).toBe(5000)
expect(modelInfo.info.contextWindow).toBe(128_000); expect(modelInfo.info.contextWindow).toBe(128_000)
}); })
}); })
}); })

View File

@@ -1,11 +1,11 @@
import { DeepSeekHandler } from '../deepseek'; import { DeepSeekHandler } from "../deepseek"
import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api'; import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api"
import OpenAI from 'openai'; import OpenAI from "openai"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client // Mock OpenAI client
const mockCreate = jest.fn(); const mockCreate = jest.fn()
jest.mock('openai', () => { jest.mock("openai", () => {
return { return {
__esModule: true, __esModule: true,
default: jest.fn().mockImplementation(() => ({ default: jest.fn().mockImplementation(() => ({
@@ -14,190 +14,204 @@ jest.mock('openai', () => {
create: mockCreate.mockImplementation(async (options) => { create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
choices: [{ choices: [
message: { role: 'assistant', content: 'Test response', refusal: null }, {
finish_reason: 'stop', message: { role: "assistant", content: "Test response", refusal: null },
index: 0 finish_reason: "stop",
}], index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
};
} }
// Return async iterator for streaming // Return async iterator for streaming
return { return {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
yield { yield {
choices: [{ choices: [
delta: { content: 'Test response' }, {
index: 0 delta: { content: "Test response" },
}], index: 0,
usage: null },
}; ],
usage: null,
}
yield { yield {
choices: [{ choices: [
{
delta: {}, delta: {},
index: 0 index: 0,
}], },
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
}; },
} }
}; }),
}) },
},
})),
} }
} })
}))
};
});
describe('DeepSeekHandler', () => { describe("DeepSeekHandler", () => {
let handler: DeepSeekHandler; let handler: DeepSeekHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
deepSeekApiKey: 'test-api-key', deepSeekApiKey: "test-api-key",
deepSeekModelId: 'deepseek-chat', deepSeekModelId: "deepseek-chat",
deepSeekBaseUrl: 'https://api.deepseek.com/v1' deepSeekBaseUrl: "https://api.deepseek.com/v1",
}; }
handler = new DeepSeekHandler(mockOptions); handler = new DeepSeekHandler(mockOptions)
mockCreate.mockClear(); mockCreate.mockClear()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(DeepSeekHandler); expect(handler).toBeInstanceOf(DeepSeekHandler)
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId); expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId)
}); })
it('should throw error if API key is missing', () => { it("should throw error if API key is missing", () => {
expect(() => { expect(() => {
new DeepSeekHandler({ new DeepSeekHandler({
...mockOptions, ...mockOptions,
deepSeekApiKey: undefined deepSeekApiKey: undefined,
}); })
}).toThrow('DeepSeek API key is required'); }).toThrow("DeepSeek API key is required")
}); })
it('should use default model ID if not provided', () => { it("should use default model ID if not provided", () => {
const handlerWithoutModel = new DeepSeekHandler({ const handlerWithoutModel = new DeepSeekHandler({
...mockOptions, ...mockOptions,
deepSeekModelId: undefined deepSeekModelId: undefined,
}); })
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId); expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId)
}); })
it('should use default base URL if not provided', () => { it("should use default base URL if not provided", () => {
const handlerWithoutBaseUrl = new DeepSeekHandler({ const handlerWithoutBaseUrl = new DeepSeekHandler({
...mockOptions, ...mockOptions,
deepSeekBaseUrl: undefined deepSeekBaseUrl: undefined,
}); })
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler); expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler)
// The base URL is passed to OpenAI client internally // The base URL is passed to OpenAI client internally
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ expect(OpenAI).toHaveBeenCalledWith(
baseURL: 'https://api.deepseek.com/v1' expect.objectContaining({
})); baseURL: "https://api.deepseek.com/v1",
}); }),
)
})
it('should use custom base URL if provided', () => { it("should use custom base URL if provided", () => {
const customBaseUrl = 'https://custom.deepseek.com/v1'; const customBaseUrl = "https://custom.deepseek.com/v1"
const handlerWithCustomUrl = new DeepSeekHandler({ const handlerWithCustomUrl = new DeepSeekHandler({
...mockOptions, ...mockOptions,
deepSeekBaseUrl: customBaseUrl deepSeekBaseUrl: customBaseUrl,
}); })
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler); expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler)
// The custom base URL is passed to OpenAI client // The custom base URL is passed to OpenAI client
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ expect(OpenAI).toHaveBeenCalledWith(
baseURL: customBaseUrl expect.objectContaining({
})); baseURL: customBaseUrl,
}); }),
)
})
it('should set includeMaxTokens to true', () => { it("should set includeMaxTokens to true", () => {
// Create a new handler and verify OpenAI client was called with includeMaxTokens // Create a new handler and verify OpenAI client was called with includeMaxTokens
new DeepSeekHandler(mockOptions); new DeepSeekHandler(mockOptions)
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ expect(OpenAI).toHaveBeenCalledWith(
apiKey: mockOptions.deepSeekApiKey expect.objectContaining({
})); apiKey: mockOptions.deepSeekApiKey,
}); }),
}); )
})
})
describe('getModel', () => { describe("getModel", () => {
it('should return model info for valid model ID', () => { it("should return model info for valid model ID", () => {
const model = handler.getModel(); const model = handler.getModel()
expect(model.id).toBe(mockOptions.deepSeekModelId); expect(model.id).toBe(mockOptions.deepSeekModelId)
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
expect(model.info.maxTokens).toBe(8192); expect(model.info.maxTokens).toBe(8192)
expect(model.info.contextWindow).toBe(64_000); expect(model.info.contextWindow).toBe(64_000)
expect(model.info.supportsImages).toBe(false); expect(model.info.supportsImages).toBe(false)
expect(model.info.supportsPromptCache).toBe(false); expect(model.info.supportsPromptCache).toBe(false)
}); })
it('should return provided model ID with default model info if model does not exist', () => { it("should return provided model ID with default model info if model does not exist", () => {
const handlerWithInvalidModel = new DeepSeekHandler({ const handlerWithInvalidModel = new DeepSeekHandler({
...mockOptions, ...mockOptions,
deepSeekModelId: 'invalid-model' deepSeekModelId: "invalid-model",
}); })
const model = handlerWithInvalidModel.getModel(); const model = handlerWithInvalidModel.getModel()
expect(model.id).toBe('invalid-model'); // Returns provided ID expect(model.id).toBe("invalid-model") // Returns provided ID
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
expect(model.info).toBe(handler.getModel().info); // But uses default model info expect(model.info).toBe(handler.getModel().info) // But uses default model info
}); })
it('should return default model if no model ID is provided', () => { it("should return default model if no model ID is provided", () => {
const handlerWithoutModel = new DeepSeekHandler({ const handlerWithoutModel = new DeepSeekHandler({
...mockOptions, ...mockOptions,
deepSeekModelId: undefined deepSeekModelId: undefined,
}); })
const model = handlerWithoutModel.getModel(); const model = handlerWithoutModel.getModel()
expect(model.id).toBe(deepSeekDefaultModelId); expect(model.id).toBe(deepSeekDefaultModelId)
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [{ content: [
type: 'text' as const, {
text: 'Hello!' type: "text" as const,
}] text: "Hello!",
} },
]; ],
},
]
it('should handle streaming responses', async () => { it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter(chunk => chunk.type === 'text'); const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1); expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe('Test response'); expect(textChunks[0].text).toBe("Test response")
}); })
it('should include usage information', async () => { it("should include usage information", async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
const usageChunks = chunks.filter(chunk => chunk.type === 'usage'); const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
expect(usageChunks.length).toBeGreaterThan(0); expect(usageChunks.length).toBeGreaterThan(0)
expect(usageChunks[0].inputTokens).toBe(10); expect(usageChunks[0].inputTokens).toBe(10)
expect(usageChunks[0].outputTokens).toBe(5); expect(usageChunks[0].outputTokens).toBe(5)
}); })
}); })
}); })

View File

@@ -1,212 +1,210 @@
import { GeminiHandler } from '../gemini'; import { GeminiHandler } from "../gemini"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
import { GoogleGenerativeAI } from '@google/generative-ai'; import { GoogleGenerativeAI } from "@google/generative-ai"
// Mock the Google Generative AI SDK // Mock the Google Generative AI SDK
jest.mock('@google/generative-ai', () => ({ jest.mock("@google/generative-ai", () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: jest.fn().mockReturnValue({ getGenerativeModel: jest.fn().mockReturnValue({
generateContentStream: jest.fn(), generateContentStream: jest.fn(),
generateContent: jest.fn().mockResolvedValue({ generateContent: jest.fn().mockResolvedValue({
response: { response: {
text: () => 'Test response' text: () => "Test response",
} },
}) }),
}) }),
})) })),
})); }))
describe('GeminiHandler', () => { describe("GeminiHandler", () => {
let handler: GeminiHandler; let handler: GeminiHandler
beforeEach(() => { beforeEach(() => {
handler = new GeminiHandler({ handler = new GeminiHandler({
apiKey: 'test-key', apiKey: "test-key",
apiModelId: 'gemini-2.0-flash-thinking-exp-1219', apiModelId: "gemini-2.0-flash-thinking-exp-1219",
geminiApiKey: 'test-key' geminiApiKey: "test-key",
}); })
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided config', () => { it("should initialize with provided config", () => {
expect(handler['options'].geminiApiKey).toBe('test-key'); expect(handler["options"].geminiApiKey).toBe("test-key")
expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219'); expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
}); })
it('should throw if API key is missing', () => { it("should throw if API key is missing", () => {
expect(() => { expect(() => {
new GeminiHandler({ new GeminiHandler({
apiModelId: 'gemini-2.0-flash-thinking-exp-1219', apiModelId: "gemini-2.0-flash-thinking-exp-1219",
geminiApiKey: '' geminiApiKey: "",
}); })
}).toThrow('API key is required for Google Gemini'); }).toThrow("API key is required for Google Gemini")
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [ const mockMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello' content: "Hello",
}, },
{ {
role: 'assistant', role: "assistant",
content: 'Hi there!' content: "Hi there!",
} },
]; ]
const systemPrompt = 'You are a helpful assistant'; const systemPrompt = "You are a helpful assistant"
it('should handle text messages correctly', async () => { it("should handle text messages correctly", async () => {
// Mock the stream response // Mock the stream response
const mockStream = { const mockStream = {
stream: [ stream: [{ text: () => "Hello" }, { text: () => " world!" }],
{ text: () => 'Hello' },
{ text: () => ' world!' }
],
response: { response: {
usageMetadata: { usageMetadata: {
promptTokenCount: 10, promptTokenCount: 10,
candidatesTokenCount: 5 candidatesTokenCount: 5,
},
},
} }
}
};
// Setup the mock implementation // Setup the mock implementation
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream); const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
const mockGetGenerativeModel = jest.fn().mockReturnValue({ const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContentStream: mockGenerateContentStream generateContentStream: mockGenerateContentStream,
}); })
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []; const chunks = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
// Should have 3 chunks: 'Hello', ' world!', and usage info // Should have 3 chunks: 'Hello', ' world!', and usage info
expect(chunks.length).toBe(3); expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'text', type: "text",
text: 'Hello' text: "Hello",
}); })
expect(chunks[1]).toEqual({ expect(chunks[1]).toEqual({
type: 'text', type: "text",
text: ' world!' text: " world!",
}); })
expect(chunks[2]).toEqual({ expect(chunks[2]).toEqual({
type: 'usage', type: "usage",
inputTokens: 10, inputTokens: 10,
outputTokens: 5 outputTokens: 5,
}); })
// Verify the model configuration // Verify the model configuration
expect(mockGetGenerativeModel).toHaveBeenCalledWith({ expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: 'gemini-2.0-flash-thinking-exp-1219', model: "gemini-2.0-flash-thinking-exp-1219",
systemInstruction: systemPrompt systemInstruction: systemPrompt,
}); })
// Verify generation config // Verify generation config
expect(mockGenerateContentStream).toHaveBeenCalledWith( expect(mockGenerateContentStream).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
generationConfig: { generationConfig: {
temperature: 0 temperature: 0,
} },
}),
)
}) })
);
});
it('should handle API errors', async () => { it("should handle API errors", async () => {
const mockError = new Error('Gemini API error'); const mockError = new Error("Gemini API error")
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError); const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
const mockGetGenerativeModel = jest.fn().mockReturnValue({ const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContentStream: mockGenerateContentStream generateContentStream: mockGenerateContentStream,
}); })
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should throw before yielding any chunks // Should throw before yielding any chunks
} }
}).rejects.toThrow('Gemini API error'); }).rejects.toThrow("Gemini API error")
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({ const mockGenerateContent = jest.fn().mockResolvedValue({
response: { response: {
text: () => 'Test response' text: () => "Test response",
} },
}); })
const mockGetGenerativeModel = jest.fn().mockReturnValue({ const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent generateContent: mockGenerateContent,
}); })
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockGetGenerativeModel).toHaveBeenCalledWith({ expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: 'gemini-2.0-flash-thinking-exp-1219' model: "gemini-2.0-flash-thinking-exp-1219",
}); })
expect(mockGenerateContent).toHaveBeenCalledWith({ expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }], contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
generationConfig: { generationConfig: {
temperature: 0 temperature: 0,
} },
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
const mockError = new Error('Gemini API error'); const mockError = new Error("Gemini API error")
const mockGenerateContent = jest.fn().mockRejectedValue(mockError); const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
const mockGetGenerativeModel = jest.fn().mockReturnValue({ const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent generateContent: mockGenerateContent,
}); })
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
.rejects.toThrow('Gemini completion error: Gemini API error'); "Gemini completion error: Gemini API error",
}); )
})
it('should handle empty response', async () => { it("should handle empty response", async () => {
const mockGenerateContent = jest.fn().mockResolvedValue({ const mockGenerateContent = jest.fn().mockResolvedValue({
response: { response: {
text: () => '' text: () => "",
} },
}); })
const mockGetGenerativeModel = jest.fn().mockReturnValue({ const mockGetGenerativeModel = jest.fn().mockReturnValue({
generateContent: mockGenerateContent generateContent: mockGenerateContent,
}); })
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel; ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return correct model info', () => { it("should return correct model info", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192); expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(32_767); expect(modelInfo.info.contextWindow).toBe(32_767)
}); })
it('should return default model if invalid model specified', () => { it("should return default model if invalid model specified", () => {
const invalidHandler = new GeminiHandler({ const invalidHandler = new GeminiHandler({
apiModelId: 'invalid-model', apiModelId: "invalid-model",
geminiApiKey: 'test-key' geminiApiKey: "test-key",
}); })
const modelInfo = invalidHandler.getModel(); const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); // Default model expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model
}); })
}); })
}); })

View File

@@ -1,14 +1,14 @@
import { GlamaHandler } from '../glama'; import { GlamaHandler } from "../glama"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from 'openai'; import OpenAI from "openai"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
import axios from 'axios'; import axios from "axios"
// Mock OpenAI client // Mock OpenAI client
const mockCreate = jest.fn(); const mockCreate = jest.fn()
const mockWithResponse = jest.fn(); const mockWithResponse = jest.fn()
jest.mock('openai', () => { jest.mock("openai", () => {
return { return {
__esModule: true, __esModule: true,
default: jest.fn().mockImplementation(() => ({ default: jest.fn().mockImplementation(() => ({
@@ -18,209 +18,221 @@ jest.mock('openai', () => {
const stream = { const stream = {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
yield { yield {
choices: [{ choices: [
delta: { content: 'Test response' }, {
index: 0 delta: { content: "Test response" },
}], index: 0,
usage: null },
}; ],
usage: null,
}
yield { yield {
choices: [{ choices: [
{
delta: {}, delta: {},
index: 0 index: 0,
}], },
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
}; },
} }
};
const result = mockCreate(...args); const result = mockCreate(...args)
if (args[0].stream) { if (args[0].stream) {
mockWithResponse.mockReturnValue(Promise.resolve({ mockWithResponse.mockReturnValue(
Promise.resolve({
data: stream, data: stream,
response: { response: {
headers: { headers: {
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null get: (name: string) =>
name === "x-completion-request-id" ? "test-request-id" : null,
},
},
}),
)
result.withResponse = mockWithResponse
} }
return result
},
},
},
})),
} }
})); })
result.withResponse = mockWithResponse;
}
return result;
}
}
}
}))
};
});
describe('GlamaHandler', () => { describe("GlamaHandler", () => {
let handler: GlamaHandler; let handler: GlamaHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
apiModelId: 'anthropic/claude-3-5-sonnet', apiModelId: "anthropic/claude-3-5-sonnet",
glamaModelId: 'anthropic/claude-3-5-sonnet', glamaModelId: "anthropic/claude-3-5-sonnet",
glamaApiKey: 'test-api-key' glamaApiKey: "test-api-key",
}; }
handler = new GlamaHandler(mockOptions); handler = new GlamaHandler(mockOptions)
mockCreate.mockClear(); mockCreate.mockClear()
mockWithResponse.mockClear(); mockWithResponse.mockClear()
// Default mock implementation for non-streaming responses // Default mock implementation for non-streaming responses
mockCreate.mockResolvedValue({ mockCreate.mockResolvedValue({
id: 'test-completion', id: "test-completion",
choices: [{ choices: [
message: { role: 'assistant', content: 'Test response' }, {
finish_reason: 'stop', message: { role: "assistant", content: "Test response" },
index: 0 finish_reason: "stop",
}], index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
} },
}); })
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(GlamaHandler); expect(handler).toBeInstanceOf(GlamaHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId); expect(handler.getModel().id).toBe(mockOptions.apiModelId)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello!' content: "Hello!",
} },
]; ]
it('should handle streaming responses', async () => { it("should handle streaming responses", async () => {
// Mock axios for token usage request // Mock axios for token usage request
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({ const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({
data: { data: {
tokenUsage: { tokenUsage: {
promptTokens: 10, promptTokens: 10,
completionTokens: 5, completionTokens: 5,
cacheCreationInputTokens: 0, cacheCreationInputTokens: 0,
cacheReadInputTokens: 0 cacheReadInputTokens: 0,
}, },
totalCostUsd: "0.00" totalCostUsd: "0.00",
} },
}); })
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBe(2); // Text chunk and usage chunk expect(chunks.length).toBe(2) // Text chunk and usage chunk
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'text', type: "text",
text: 'Test response' text: "Test response",
}); })
expect(chunks[1]).toEqual({ expect(chunks[1]).toEqual({
type: 'usage', type: "usage",
inputTokens: 10, inputTokens: 10,
outputTokens: 5, outputTokens: 5,
cacheWriteTokens: 0, cacheWriteTokens: 0,
cacheReadTokens: 0, cacheReadTokens: 0,
totalCost: 0 totalCost: 0,
}); })
mockAxios.mockRestore(); mockAxios.mockRestore()
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockImplementationOnce(() => { mockCreate.mockImplementationOnce(() => {
throw new Error('API Error'); throw new Error("API Error")
}); })
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks = []; const chunks = []
try { try {
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
fail('Expected error to be thrown'); fail("Expected error to be thrown")
} catch (error) { } catch (error) {
expect(error).toBeInstanceOf(Error); expect(error).toBeInstanceOf(Error)
expect(error.message).toBe('API Error'); expect(error.message).toBe("API Error")
} }
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: mockOptions.apiModelId, model: mockOptions.apiModelId,
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
temperature: 0, temperature: 0,
max_tokens: 8192 max_tokens: 8192,
})); }),
}); )
})
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error")
.rejects.toThrow('Glama completion error: API Error'); })
});
it('should handle empty response', async () => { it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({ mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }] choices: [{ message: { content: "" } }],
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
it('should not set max_tokens for non-Anthropic models', async () => { it("should not set max_tokens for non-Anthropic models", async () => {
// Reset mock to clear any previous calls // Reset mock to clear any previous calls
mockCreate.mockClear(); mockCreate.mockClear()
const nonAnthropicOptions = { const nonAnthropicOptions = {
apiModelId: 'openai/gpt-4', apiModelId: "openai/gpt-4",
glamaModelId: 'openai/gpt-4', glamaModelId: "openai/gpt-4",
glamaApiKey: 'test-key', glamaApiKey: "test-key",
glamaModelInfo: { glamaModelInfo: {
maxTokens: 4096, maxTokens: 4096,
contextWindow: 8192, contextWindow: 8192,
supportsImages: true, supportsImages: true,
supportsPromptCache: false supportsPromptCache: false,
},
} }
}; const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions)
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
await nonAnthropicHandler.completePrompt('Test prompt'); await nonAnthropicHandler.completePrompt("Test prompt")
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ expect(mockCreate).toHaveBeenCalledWith(
model: 'openai/gpt-4', expect.objectContaining({
messages: [{ role: 'user', content: 'Test prompt' }], model: "openai/gpt-4",
temperature: 0 messages: [{ role: "user", content: "Test prompt" }],
})); temperature: 0,
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens'); }),
}); )
}); expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
})
})
describe('getModel', () => { describe("getModel", () => {
it('should return model info', () => { it("should return model info", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.apiModelId); expect(modelInfo.id).toBe(mockOptions.apiModelId)
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192); expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(200_000); expect(modelInfo.info.contextWindow).toBe(200_000)
}); })
}); })
}); })

View File

@@ -1,11 +1,11 @@
import { LmStudioHandler } from '../lmstudio'; import { LmStudioHandler } from "../lmstudio"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from 'openai'; import OpenAI from "openai"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client // Mock OpenAI client
const mockCreate = jest.fn(); const mockCreate = jest.fn()
jest.mock('openai', () => { jest.mock("openai", () => {
return { return {
__esModule: true, __esModule: true,
default: jest.fn().mockImplementation(() => ({ default: jest.fn().mockImplementation(() => ({
@@ -14,147 +14,154 @@ jest.mock('openai', () => {
create: mockCreate.mockImplementation(async (options) => { create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
choices: [{ choices: [
message: { role: 'assistant', content: 'Test response' }, {
finish_reason: 'stop', message: { role: "assistant", content: "Test response" },
index: 0 finish_reason: "stop",
}], index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
};
} }
return { return {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
yield { yield {
choices: [{ choices: [
delta: { content: 'Test response' }, {
index: 0 delta: { content: "Test response" },
}], index: 0,
usage: null },
}; ],
usage: null,
}
yield { yield {
choices: [{ choices: [
{
delta: {}, delta: {},
index: 0 index: 0,
}], },
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
}; },
} }
}; }),
}) },
},
})),
} }
} })
}))
};
});
describe('LmStudioHandler', () => { describe("LmStudioHandler", () => {
let handler: LmStudioHandler; let handler: LmStudioHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
apiModelId: 'local-model', apiModelId: "local-model",
lmStudioModelId: 'local-model', lmStudioModelId: "local-model",
lmStudioBaseUrl: 'http://localhost:1234/v1' lmStudioBaseUrl: "http://localhost:1234/v1",
}; }
handler = new LmStudioHandler(mockOptions); handler = new LmStudioHandler(mockOptions)
mockCreate.mockClear(); mockCreate.mockClear()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(LmStudioHandler); expect(handler).toBeInstanceOf(LmStudioHandler)
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId); expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId)
}); })
it('should use default base URL if not provided', () => { it("should use default base URL if not provided", () => {
const handlerWithoutUrl = new LmStudioHandler({ const handlerWithoutUrl = new LmStudioHandler({
apiModelId: 'local-model', apiModelId: "local-model",
lmStudioModelId: 'local-model' lmStudioModelId: "local-model",
}); })
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler); expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello!' content: "Hello!",
} },
]; ]
it('should handle streaming responses', async () => { it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter(chunk => chunk.type === 'text'); const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1); expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe('Test response'); expect(textChunks[0].text).toBe("Test response")
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should not reach here // Should not reach here
} }
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong'); }).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong")
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.lmStudioModelId, model: mockOptions.lmStudioModelId,
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
temperature: 0, temperature: 0,
stream: false stream: false,
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong'); "Please check the LM Studio developer logs to debug what went wrong",
}); )
})
it('should handle empty response', async () => { it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({ mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }] choices: [{ message: { content: "" } }],
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return model info', () => { it("should return model info", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId); expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(-1); expect(modelInfo.info.maxTokens).toBe(-1)
expect(modelInfo.info.contextWindow).toBe(128_000); expect(modelInfo.info.contextWindow).toBe(128_000)
}); })
}); })
}); })

View File

@@ -1,11 +1,11 @@
import { OllamaHandler } from '../ollama'; import { OllamaHandler } from "../ollama"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from 'openai'; import OpenAI from "openai"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client // Mock OpenAI client
const mockCreate = jest.fn(); const mockCreate = jest.fn()
jest.mock('openai', () => { jest.mock("openai", () => {
return { return {
__esModule: true, __esModule: true,
default: jest.fn().mockImplementation(() => ({ default: jest.fn().mockImplementation(() => ({
@@ -14,147 +14,152 @@ jest.mock('openai', () => {
create: mockCreate.mockImplementation(async (options) => { create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
choices: [{ choices: [
message: { role: 'assistant', content: 'Test response' }, {
finish_reason: 'stop', message: { role: "assistant", content: "Test response" },
index: 0 finish_reason: "stop",
}], index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
};
} }
return { return {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
yield { yield {
choices: [{ choices: [
delta: { content: 'Test response' }, {
index: 0 delta: { content: "Test response" },
}], index: 0,
usage: null },
}; ],
usage: null,
}
yield { yield {
choices: [{ choices: [
{
delta: {}, delta: {},
index: 0 index: 0,
}], },
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
}; },
} }
}; }),
}) },
},
})),
} }
} })
}))
};
});
describe('OllamaHandler', () => { describe("OllamaHandler", () => {
let handler: OllamaHandler; let handler: OllamaHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
apiModelId: 'llama2', apiModelId: "llama2",
ollamaModelId: 'llama2', ollamaModelId: "llama2",
ollamaBaseUrl: 'http://localhost:11434/v1' ollamaBaseUrl: "http://localhost:11434/v1",
}; }
handler = new OllamaHandler(mockOptions); handler = new OllamaHandler(mockOptions)
mockCreate.mockClear(); mockCreate.mockClear()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(OllamaHandler); expect(handler).toBeInstanceOf(OllamaHandler)
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId); expect(handler.getModel().id).toBe(mockOptions.ollamaModelId)
}); })
it('should use default base URL if not provided', () => { it("should use default base URL if not provided", () => {
const handlerWithoutUrl = new OllamaHandler({ const handlerWithoutUrl = new OllamaHandler({
apiModelId: 'llama2', apiModelId: "llama2",
ollamaModelId: 'llama2' ollamaModelId: "llama2",
}); })
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler); expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello!' content: "Hello!",
} },
]; ]
it('should handle streaming responses', async () => { it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter(chunk => chunk.type === 'text'); const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1); expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe('Test response'); expect(textChunks[0].text).toBe("Test response")
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should not reach here // Should not reach here
} }
}).rejects.toThrow('API Error'); }).rejects.toThrow("API Error")
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.ollamaModelId, model: mockOptions.ollamaModelId,
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
temperature: 0, temperature: 0,
stream: false stream: false,
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error")
.rejects.toThrow('Ollama completion error: API Error'); })
});
it('should handle empty response', async () => { it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({ mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }] choices: [{ message: { content: "" } }],
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return model info', () => { it("should return model info", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.ollamaModelId); expect(modelInfo.id).toBe(mockOptions.ollamaModelId)
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(-1); expect(modelInfo.info.maxTokens).toBe(-1)
expect(modelInfo.info.contextWindow).toBe(128_000); expect(modelInfo.info.contextWindow).toBe(128_000)
}); })
}); })
}); })

View File

@@ -1,11 +1,11 @@
import { OpenAiNativeHandler } from '../openai-native'; import { OpenAiNativeHandler } from "../openai-native"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from 'openai'; import OpenAI from "openai"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client // Mock OpenAI client
const mockCreate = jest.fn(); const mockCreate = jest.fn()
jest.mock('openai', () => { jest.mock("openai", () => {
return { return {
__esModule: true, __esModule: true,
default: jest.fn().mockImplementation(() => ({ default: jest.fn().mockImplementation(() => ({
@@ -14,306 +14,313 @@ jest.mock('openai', () => {
create: mockCreate.mockImplementation(async (options) => { create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
choices: [{ choices: [
message: { role: 'assistant', content: 'Test response' }, {
finish_reason: 'stop', message: { role: "assistant", content: "Test response" },
index: 0 finish_reason: "stop",
}], index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
};
} }
return { return {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
yield { yield {
choices: [{ choices: [
delta: { content: 'Test response' }, {
index: 0 delta: { content: "Test response" },
}], index: 0,
usage: null },
}; ],
usage: null,
}
yield { yield {
choices: [{ choices: [
{
delta: {}, delta: {},
index: 0 index: 0,
}], },
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
}; },
} }
}; }),
}) },
},
})),
} }
} })
}))
};
});
describe('OpenAiNativeHandler', () => { describe("OpenAiNativeHandler", () => {
let handler: OpenAiNativeHandler; let handler: OpenAiNativeHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello!' content: "Hello!",
} },
]; ]
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
apiModelId: 'gpt-4o', apiModelId: "gpt-4o",
openAiNativeApiKey: 'test-api-key' openAiNativeApiKey: "test-api-key",
}; }
handler = new OpenAiNativeHandler(mockOptions); handler = new OpenAiNativeHandler(mockOptions)
mockCreate.mockClear(); mockCreate.mockClear()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(OpenAiNativeHandler); expect(handler).toBeInstanceOf(OpenAiNativeHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId); expect(handler.getModel().id).toBe(mockOptions.apiModelId)
}); })
it('should initialize with empty API key', () => { it("should initialize with empty API key", () => {
const handlerWithoutKey = new OpenAiNativeHandler({ const handlerWithoutKey = new OpenAiNativeHandler({
apiModelId: 'gpt-4o', apiModelId: "gpt-4o",
openAiNativeApiKey: '' openAiNativeApiKey: "",
}); })
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler); expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
it('should handle streaming responses', async () => { it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter(chunk => chunk.type === 'text'); const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1); expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe('Test response'); expect(textChunks[0].text).toBe("Test response")
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should not reach here // Should not reach here
} }
}).rejects.toThrow('API Error'); }).rejects.toThrow("API Error")
}); })
it('should handle missing content in response for o1 model', async () => { it("should handle missing content in response for o1 model", async () => {
// Use o1 model which supports developer role // Use o1 model which supports developer role
handler = new OpenAiNativeHandler({ handler = new OpenAiNativeHandler({
...mockOptions, ...mockOptions,
apiModelId: 'o1' apiModelId: "o1",
}); })
mockCreate.mockResolvedValueOnce({ mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: null } }], choices: [{ message: { content: null } }],
usage: { usage: {
prompt_tokens: 0, prompt_tokens: 0,
completion_tokens: 0, completion_tokens: 0,
total_tokens: 0 total_tokens: 0,
} },
}); })
const generator = handler.createMessage(systemPrompt, messages); const generator = handler.createMessage(systemPrompt, messages)
const results = []; const results = []
for await (const result of generator) { for await (const result of generator) {
results.push(result); results.push(result)
} }
expect(results).toEqual([ expect(results).toEqual([
{ type: 'text', text: '' }, { type: "text", text: "" },
{ type: 'usage', inputTokens: 0, outputTokens: 0 } { type: "usage", inputTokens: 0, outputTokens: 0 },
]); ])
// Verify developer role is used for system prompt with o1 model // Verify developer role is used for system prompt with o1 model
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'o1', model: "o1",
messages: [ messages: [
{ role: 'developer', content: systemPrompt }, { role: "developer", content: systemPrompt },
{ role: 'user', content: 'Hello!' } { role: "user", content: "Hello!" },
] ],
}); })
}); })
}); })
describe('streaming models', () => { describe("streaming models", () => {
beforeEach(() => { beforeEach(() => {
handler = new OpenAiNativeHandler({ handler = new OpenAiNativeHandler({
...mockOptions, ...mockOptions,
apiModelId: 'gpt-4o', apiModelId: "gpt-4o",
}); })
}); })
it('should handle streaming response', async () => { it("should handle streaming response", async () => {
const mockStream = [ const mockStream = [
{ choices: [{ delta: { content: 'Hello' } }], usage: null }, { choices: [{ delta: { content: "Hello" } }], usage: null },
{ choices: [{ delta: { content: ' there' } }], usage: null }, { choices: [{ delta: { content: " there" } }], usage: null },
{ choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, { choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]; ]
mockCreate.mockResolvedValueOnce( mockCreate.mockResolvedValueOnce(
(async function* () { (async function* () {
for (const chunk of mockStream) { for (const chunk of mockStream) {
yield chunk; yield chunk
} }
})() })(),
); )
const generator = handler.createMessage(systemPrompt, messages); const generator = handler.createMessage(systemPrompt, messages)
const results = []; const results = []
for await (const result of generator) { for await (const result of generator) {
results.push(result); results.push(result)
} }
expect(results).toEqual([ expect(results).toEqual([
{ type: 'text', text: 'Hello' }, { type: "text", text: "Hello" },
{ type: 'text', text: ' there' }, { type: "text", text: " there" },
{ type: 'text', text: '!' }, { type: "text", text: "!" },
{ type: 'usage', inputTokens: 10, outputTokens: 5 }, { type: "usage", inputTokens: 10, outputTokens: 5 },
]); ])
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'gpt-4o', model: "gpt-4o",
temperature: 0, temperature: 0,
messages: [ messages: [
{ role: 'system', content: systemPrompt }, { role: "system", content: systemPrompt },
{ role: 'user', content: 'Hello!' }, { role: "user", content: "Hello!" },
], ],
stream: true, stream: true,
stream_options: { include_usage: true }, stream_options: { include_usage: true },
}); })
}); })
it('should handle empty delta content', async () => { it("should handle empty delta content", async () => {
const mockStream = [ const mockStream = [
{ choices: [{ delta: {} }], usage: null }, { choices: [{ delta: {} }], usage: null },
{ choices: [{ delta: { content: null } }], usage: null }, { choices: [{ delta: { content: null } }], usage: null },
{ choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } }, { choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]; ]
mockCreate.mockResolvedValueOnce( mockCreate.mockResolvedValueOnce(
(async function* () { (async function* () {
for (const chunk of mockStream) { for (const chunk of mockStream) {
yield chunk; yield chunk
} }
})() })(),
); )
const generator = handler.createMessage(systemPrompt, messages); const generator = handler.createMessage(systemPrompt, messages)
const results = []; const results = []
for await (const result of generator) { for await (const result of generator) {
results.push(result); results.push(result)
} }
expect(results).toEqual([ expect(results).toEqual([
{ type: 'text', text: 'Hello' }, { type: "text", text: "Hello" },
{ type: 'usage', inputTokens: 10, outputTokens: 5 }, { type: "usage", inputTokens: 10, outputTokens: 5 },
]); ])
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully with gpt-4o model', async () => { it("should complete prompt successfully with gpt-4o model", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'gpt-4o', model: "gpt-4o",
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
temperature: 0 temperature: 0,
}); })
}); })
it('should complete prompt successfully with o1 model', async () => { it("should complete prompt successfully with o1 model", async () => {
handler = new OpenAiNativeHandler({ handler = new OpenAiNativeHandler({
apiModelId: 'o1', apiModelId: "o1",
openAiNativeApiKey: 'test-api-key' openAiNativeApiKey: "test-api-key",
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'o1', model: "o1",
messages: [{ role: 'user', content: 'Test prompt' }] messages: [{ role: "user", content: "Test prompt" }],
}); })
}); })
it('should complete prompt successfully with o1-preview model', async () => { it("should complete prompt successfully with o1-preview model", async () => {
handler = new OpenAiNativeHandler({ handler = new OpenAiNativeHandler({
apiModelId: 'o1-preview', apiModelId: "o1-preview",
openAiNativeApiKey: 'test-api-key' openAiNativeApiKey: "test-api-key",
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-preview', model: "o1-preview",
messages: [{ role: 'user', content: 'Test prompt' }] messages: [{ role: "user", content: "Test prompt" }],
}); })
}); })
it('should complete prompt successfully with o1-mini model', async () => { it("should complete prompt successfully with o1-mini model", async () => {
handler = new OpenAiNativeHandler({ handler = new OpenAiNativeHandler({
apiModelId: 'o1-mini', apiModelId: "o1-mini",
openAiNativeApiKey: 'test-api-key' openAiNativeApiKey: "test-api-key",
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'o1-mini', model: "o1-mini",
messages: [{ role: 'user', content: 'Test prompt' }] messages: [{ role: "user", content: "Test prompt" }],
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
.rejects.toThrow('OpenAI Native completion error: API Error'); "OpenAI Native completion error: API Error",
}); )
})
it('should handle empty response', async () => { it("should handle empty response", async () => {
mockCreate.mockResolvedValueOnce({ mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: '' } }] choices: [{ message: { content: "" } }],
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return model info', () => { it("should return model info", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(mockOptions.apiModelId); expect(modelInfo.id).toBe(mockOptions.apiModelId)
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(4096); expect(modelInfo.info.maxTokens).toBe(4096)
expect(modelInfo.info.contextWindow).toBe(128_000); expect(modelInfo.info.contextWindow).toBe(128_000)
}); })
it('should handle undefined model ID', () => { it("should handle undefined model ID", () => {
const handlerWithoutModel = new OpenAiNativeHandler({ const handlerWithoutModel = new OpenAiNativeHandler({
openAiNativeApiKey: 'test-api-key' openAiNativeApiKey: "test-api-key",
}); })
const modelInfo = handlerWithoutModel.getModel(); const modelInfo = handlerWithoutModel.getModel()
expect(modelInfo.id).toBe('gpt-4o'); // Default model expect(modelInfo.id).toBe("gpt-4o") // Default model
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
}); })
}); })
}); })

View File

@@ -1,12 +1,12 @@
import { OpenAiHandler } from '../openai'; import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import { ApiStream } from '../../transform/stream'; import { ApiStream } from "../../transform/stream"
import OpenAI from 'openai'; import OpenAI from "openai"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client // Mock OpenAI client
const mockCreate = jest.fn(); const mockCreate = jest.fn()
jest.mock('openai', () => { jest.mock("openai", () => {
return { return {
__esModule: true, __esModule: true,
default: jest.fn().mockImplementation(() => ({ default: jest.fn().mockImplementation(() => ({
@@ -15,210 +15,219 @@ jest.mock('openai', () => {
create: mockCreate.mockImplementation(async (options) => { create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
choices: [{ choices: [
message: { role: 'assistant', content: 'Test response', refusal: null }, {
finish_reason: 'stop', message: { role: "assistant", content: "Test response", refusal: null },
index: 0 finish_reason: "stop",
}], index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
};
} }
return { return {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
yield { yield {
choices: [{ choices: [
delta: { content: 'Test response' }, {
index: 0 delta: { content: "Test response" },
}], index: 0,
usage: null },
}; ],
usage: null,
}
yield { yield {
choices: [{ choices: [
{
delta: {}, delta: {},
index: 0 index: 0,
}], },
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
},
} }
}; },
} }
}; }),
}) },
},
})),
} }
} })
}))
};
});
describe('OpenAiHandler', () => { describe("OpenAiHandler", () => {
let handler: OpenAiHandler; let handler: OpenAiHandler
let mockOptions: ApiHandlerOptions; let mockOptions: ApiHandlerOptions
beforeEach(() => { beforeEach(() => {
mockOptions = { mockOptions = {
openAiApiKey: 'test-api-key', openAiApiKey: "test-api-key",
openAiModelId: 'gpt-4', openAiModelId: "gpt-4",
openAiBaseUrl: 'https://api.openai.com/v1' openAiBaseUrl: "https://api.openai.com/v1",
}; }
handler = new OpenAiHandler(mockOptions); handler = new OpenAiHandler(mockOptions)
mockCreate.mockClear(); mockCreate.mockClear()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(OpenAiHandler); expect(handler).toBeInstanceOf(OpenAiHandler)
expect(handler.getModel().id).toBe(mockOptions.openAiModelId); expect(handler.getModel().id).toBe(mockOptions.openAiModelId)
}); })
it('should use custom base URL if provided', () => { it("should use custom base URL if provided", () => {
const customBaseUrl = 'https://custom.openai.com/v1'; const customBaseUrl = "https://custom.openai.com/v1"
const handlerWithCustomUrl = new OpenAiHandler({ const handlerWithCustomUrl = new OpenAiHandler({
...mockOptions, ...mockOptions,
openAiBaseUrl: customBaseUrl openAiBaseUrl: customBaseUrl,
}); })
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler); expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const systemPrompt = 'You are a helpful assistant.'; const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [{ content: [
type: 'text' as const, {
text: 'Hello!' type: "text" as const,
}] text: "Hello!",
} },
]; ],
},
]
it('should handle non-streaming mode', async () => { it("should handle non-streaming mode", async () => {
const handler = new OpenAiHandler({ const handler = new OpenAiHandler({
...mockOptions, ...mockOptions,
openAiStreamingEnabled: false openAiStreamingEnabled: false,
}); })
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
const textChunk = chunks.find(chunk => chunk.type === 'text'); const textChunk = chunks.find((chunk) => chunk.type === "text")
const usageChunk = chunks.find(chunk => chunk.type === 'usage'); const usageChunk = chunks.find((chunk) => chunk.type === "usage")
expect(textChunk).toBeDefined(); expect(textChunk).toBeDefined()
expect(textChunk?.text).toBe('Test response'); expect(textChunk?.text).toBe("Test response")
expect(usageChunk).toBeDefined(); expect(usageChunk).toBeDefined()
expect(usageChunk?.inputTokens).toBe(10); expect(usageChunk?.inputTokens).toBe(10)
expect(usageChunk?.outputTokens).toBe(5); expect(usageChunk?.outputTokens).toBe(5)
}); })
it('should handle streaming responses', async () => { it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []; const chunks: any[] = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBeGreaterThan(0); expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter(chunk => chunk.type === 'text'); const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1); expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe('Test response'); expect(textChunks[0].text).toBe("Test response")
}); })
}); })
describe('error handling', () => { describe("error handling", () => {
const testMessages: Anthropic.Messages.MessageParam[] = [ const testMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [{ content: [
type: 'text' as const, {
text: 'Hello' type: "text" as const,
}] text: "Hello",
} },
]; ],
},
]
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
const stream = handler.createMessage('system prompt', testMessages); const stream = handler.createMessage("system prompt", testMessages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should not reach here // Should not reach here
} }
}).rejects.toThrow('API Error'); }).rejects.toThrow("API Error")
}); })
it('should handle rate limiting', async () => { it("should handle rate limiting", async () => {
const rateLimitError = new Error('Rate limit exceeded'); const rateLimitError = new Error("Rate limit exceeded")
rateLimitError.name = 'Error'; rateLimitError.name = "Error"
(rateLimitError as any).status = 429; ;(rateLimitError as any).status = 429
mockCreate.mockRejectedValueOnce(rateLimitError); mockCreate.mockRejectedValueOnce(rateLimitError)
const stream = handler.createMessage('system prompt', testMessages); const stream = handler.createMessage("system prompt", testMessages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should not reach here // Should not reach here
} }
}).rejects.toThrow('Rate limit exceeded'); }).rejects.toThrow("Rate limit exceeded")
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId, model: mockOptions.openAiModelId,
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
temperature: 0 temperature: 0,
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error('API Error')); mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error")
.rejects.toThrow('OpenAI completion error: API Error'); })
});
it('should handle empty response', async () => { it("should handle empty response", async () => {
mockCreate.mockImplementationOnce(() => ({ mockCreate.mockImplementationOnce(() => ({
choices: [{ message: { content: '' } }] choices: [{ message: { content: "" } }],
})); }))
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return model info with sane defaults', () => { it("should return model info with sane defaults", () => {
const model = handler.getModel(); const model = handler.getModel()
expect(model.id).toBe(mockOptions.openAiModelId); expect(model.id).toBe(mockOptions.openAiModelId)
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
expect(model.info.contextWindow).toBe(128_000); expect(model.info.contextWindow).toBe(128_000)
expect(model.info.supportsImages).toBe(true); expect(model.info.supportsImages).toBe(true)
}); })
it('should handle undefined model ID', () => { it("should handle undefined model ID", () => {
const handlerWithoutModel = new OpenAiHandler({ const handlerWithoutModel = new OpenAiHandler({
...mockOptions, ...mockOptions,
openAiModelId: undefined openAiModelId: undefined,
}); })
const model = handlerWithoutModel.getModel(); const model = handlerWithoutModel.getModel()
expect(model.id).toBe(''); expect(model.id).toBe("")
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
}); })
}); })
}); })

View File

@@ -1,83 +1,85 @@
import { OpenRouterHandler } from '../openrouter' import { OpenRouterHandler } from "../openrouter"
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api' import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
import OpenAI from 'openai' import OpenAI from "openai"
import axios from 'axios' import axios from "axios"
import { Anthropic } from '@anthropic-ai/sdk' import { Anthropic } from "@anthropic-ai/sdk"
// Mock dependencies // Mock dependencies
jest.mock('openai') jest.mock("openai")
jest.mock('axios') jest.mock("axios")
jest.mock('delay', () => jest.fn(() => Promise.resolve())) jest.mock("delay", () => jest.fn(() => Promise.resolve()))
describe('OpenRouterHandler', () => { describe("OpenRouterHandler", () => {
const mockOptions: ApiHandlerOptions = { const mockOptions: ApiHandlerOptions = {
openRouterApiKey: 'test-key', openRouterApiKey: "test-key",
openRouterModelId: 'test-model', openRouterModelId: "test-model",
openRouterModelInfo: { openRouterModelInfo: {
name: 'Test Model', name: "Test Model",
description: 'Test Description', description: "Test Description",
maxTokens: 1000, maxTokens: 1000,
contextWindow: 2000, contextWindow: 2000,
supportsPromptCache: true, supportsPromptCache: true,
inputPrice: 0.01, inputPrice: 0.01,
outputPrice: 0.02 outputPrice: 0.02,
} as ModelInfo } as ModelInfo,
} }
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() jest.clearAllMocks()
}) })
test('constructor initializes with correct options', () => { test("constructor initializes with correct options", () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
expect(handler).toBeInstanceOf(OpenRouterHandler) expect(handler).toBeInstanceOf(OpenRouterHandler)
expect(OpenAI).toHaveBeenCalledWith({ expect(OpenAI).toHaveBeenCalledWith({
baseURL: 'https://openrouter.ai/api/v1', baseURL: "https://openrouter.ai/api/v1",
apiKey: mockOptions.openRouterApiKey, apiKey: mockOptions.openRouterApiKey,
defaultHeaders: { defaultHeaders: {
'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline', "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
'X-Title': 'Roo-Cline', "X-Title": "Roo-Cline",
}, },
}) })
}) })
test('getModel returns correct model info when options are provided', () => { test("getModel returns correct model info when options are provided", () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
const result = handler.getModel() const result = handler.getModel()
expect(result).toEqual({ expect(result).toEqual({
id: mockOptions.openRouterModelId, id: mockOptions.openRouterModelId,
info: mockOptions.openRouterModelInfo info: mockOptions.openRouterModelInfo,
}) })
}) })
test('getModel returns default model info when options are not provided', () => { test("getModel returns default model info when options are not provided", () => {
const handler = new OpenRouterHandler({}) const handler = new OpenRouterHandler({})
const result = handler.getModel() const result = handler.getModel()
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta') expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta")
expect(result.info.supportsPromptCache).toBe(true) expect(result.info.supportsPromptCache).toBe(true)
}) })
test('createMessage generates correct stream chunks', async () => { test("createMessage generates correct stream chunks", async () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
const mockStream = { const mockStream = {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
id: 'test-id', id: "test-id",
choices: [{ choices: [
{
delta: { delta: {
content: 'test response' content: "test response",
} },
}] },
} ],
} }
},
} }
// Mock OpenAI chat.completions.create // Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream) const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
// Mock axios.get for generation details // Mock axios.get for generation details
@@ -86,13 +88,13 @@ describe('OpenRouterHandler', () => {
data: { data: {
native_tokens_prompt: 10, native_tokens_prompt: 10,
native_tokens_completion: 20, native_tokens_completion: 20,
total_cost: 0.001 total_cost: 0.001,
} },
} },
}) })
const systemPrompt = 'test system prompt' const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }] const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
const generator = handler.createMessage(systemPrompt, messages) const generator = handler.createMessage(systemPrompt, messages)
const chunks = [] const chunks = []
@@ -104,180 +106,192 @@ describe('OpenRouterHandler', () => {
// Verify stream chunks // Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'text', type: "text",
text: 'test response' text: "test response",
}) })
expect(chunks[1]).toEqual({ expect(chunks[1]).toEqual({
type: 'usage', type: "usage",
inputTokens: 10, inputTokens: 10,
outputTokens: 20, outputTokens: 20,
totalCost: 0.001, totalCost: 0.001,
fullResponseText: 'test response' fullResponseText: "test response",
}) })
// Verify OpenAI client was called with correct parameters // Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: mockOptions.openRouterModelId, model: mockOptions.openRouterModelId,
temperature: 0, temperature: 0,
messages: expect.arrayContaining([ messages: expect.arrayContaining([
{ role: 'system', content: systemPrompt }, { role: "system", content: systemPrompt },
{ role: 'user', content: 'test message' } { role: "user", content: "test message" },
]), ]),
stream: true stream: true,
})) }),
)
}) })
test('createMessage with middle-out transform enabled', async () => { test("createMessage with middle-out transform enabled", async () => {
const handler = new OpenRouterHandler({ const handler = new OpenRouterHandler({
...mockOptions, ...mockOptions,
openRouterUseMiddleOutTransform: true openRouterUseMiddleOutTransform: true,
}) })
const mockStream = { const mockStream = {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
id: 'test-id', id: "test-id",
choices: [{ choices: [
{
delta: { delta: {
content: 'test response' content: "test response",
} },
}] },
} ],
} }
},
} }
const mockCreate = jest.fn().mockResolvedValue(mockStream) const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
await handler.createMessage('test', []).next() await handler.createMessage("test", []).next()
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ expect(mockCreate).toHaveBeenCalledWith(
transforms: ['middle-out'] expect.objectContaining({
})) transforms: ["middle-out"],
}),
)
}) })
test('createMessage with Claude model adds cache control', async () => { test("createMessage with Claude model adds cache control", async () => {
const handler = new OpenRouterHandler({ const handler = new OpenRouterHandler({
...mockOptions, ...mockOptions,
openRouterModelId: 'anthropic/claude-3.5-sonnet' openRouterModelId: "anthropic/claude-3.5-sonnet",
}) })
const mockStream = { const mockStream = {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
id: 'test-id', id: "test-id",
choices: [{ choices: [
{
delta: { delta: {
content: 'test response' content: "test response",
} },
}] },
} ],
} }
},
} }
const mockCreate = jest.fn().mockResolvedValue(mockStream) const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'message 1' }, { role: "user", content: "message 1" },
{ role: 'assistant', content: 'response 1' }, { role: "assistant", content: "response 1" },
{ role: 'user', content: 'message 2' } { role: "user", content: "message 2" },
] ]
await handler.createMessage('test system', messages).next() await handler.createMessage("test system", messages).next()
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([ messages: expect.arrayContaining([
expect.objectContaining({ expect.objectContaining({
role: 'system', role: "system",
content: expect.arrayContaining([ content: expect.arrayContaining([
expect.objectContaining({ expect.objectContaining({
cache_control: { type: 'ephemeral' } cache_control: { type: "ephemeral" },
}) }),
]) ]),
}) }),
]) ]),
})) }),
)
}) })
test('createMessage handles API errors', async () => { test("createMessage handles API errors", async () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
const mockStream = { const mockStream = {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
error: { error: {
message: 'API Error', message: "API Error",
code: 500 code: 500,
} },
}
} }
},
} }
const mockCreate = jest.fn().mockResolvedValue(mockStream) const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
const generator = handler.createMessage('test', []) const generator = handler.createMessage("test", [])
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error') await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
}) })
test('completePrompt returns correct response', async () => { test("completePrompt returns correct response", async () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
const mockResponse = { const mockResponse = {
choices: [{ choices: [
{
message: { message: {
content: 'test completion' content: "test completion",
} },
}] },
],
} }
const mockCreate = jest.fn().mockResolvedValue(mockResponse) const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
const result = await handler.completePrompt('test prompt') const result = await handler.completePrompt("test prompt")
expect(result).toBe('test completion') expect(result).toBe("test completion")
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openRouterModelId, model: mockOptions.openRouterModelId,
messages: [{ role: 'user', content: 'test prompt' }], messages: [{ role: "user", content: "test prompt" }],
temperature: 0, temperature: 0,
stream: false stream: false,
}) })
}) })
test('completePrompt handles API errors', async () => { test("completePrompt handles API errors", async () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
const mockError = { const mockError = {
error: { error: {
message: 'API Error', message: "API Error",
code: 500 code: 500,
} },
} }
const mockCreate = jest.fn().mockResolvedValue(mockError) const mockCreate = jest.fn().mockResolvedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
await expect(handler.completePrompt('test prompt')) await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error")
.rejects.toThrow('OpenRouter API Error 500: API Error')
}) })
test('completePrompt handles unexpected errors', async () => { test("completePrompt handles unexpected errors", async () => {
const handler = new OpenRouterHandler(mockOptions) const handler = new OpenRouterHandler(mockOptions)
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error')) const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error"))
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = { ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate } completions: { create: mockCreate },
} as any } as any
await expect(handler.completePrompt('test prompt')) await expect(handler.completePrompt("test prompt")).rejects.toThrow(
.rejects.toThrow('OpenRouter completion error: Unexpected error') "OpenRouter completion error: Unexpected error",
)
}) })
}) })

View File

@@ -1,296 +1,295 @@
import { VertexHandler } from '../vertex'; import { VertexHandler } from "../vertex"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'; import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
// Mock Vertex SDK // Mock Vertex SDK
jest.mock('@anthropic-ai/vertex-sdk', () => ({ jest.mock("@anthropic-ai/vertex-sdk", () => ({
AnthropicVertex: jest.fn().mockImplementation(() => ({ AnthropicVertex: jest.fn().mockImplementation(() => ({
messages: { messages: {
create: jest.fn().mockImplementation(async (options) => { create: jest.fn().mockImplementation(async (options) => {
if (!options.stream) { if (!options.stream) {
return { return {
id: 'test-completion', id: "test-completion",
content: [ content: [{ type: "text", text: "Test response" }],
{ type: 'text', text: 'Test response' } role: "assistant",
],
role: 'assistant',
model: options.model, model: options.model,
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 5 output_tokens: 5,
} },
} }
} }
return { return {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
yield { yield {
type: 'message_start', type: "message_start",
message: { message: {
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 5 output_tokens: 5,
} },
} },
} }
yield { yield {
type: 'content_block_start', type: "content_block_start",
content_block: { content_block: {
type: 'text', type: "text",
text: 'Test response' text: "Test response",
},
} }
},
} }
} }),
} },
}) })),
} }))
}))
}));
describe('VertexHandler', () => { describe("VertexHandler", () => {
let handler: VertexHandler; let handler: VertexHandler
beforeEach(() => { beforeEach(() => {
handler = new VertexHandler({ handler = new VertexHandler({
apiModelId: 'claude-3-5-sonnet-v2@20241022', apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: 'test-project', vertexProjectId: "test-project",
vertexRegion: 'us-central1' vertexRegion: "us-central1",
}); })
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided config', () => { it("should initialize with provided config", () => {
expect(AnthropicVertex).toHaveBeenCalledWith({ expect(AnthropicVertex).toHaveBeenCalledWith({
projectId: 'test-project', projectId: "test-project",
region: 'us-central1' region: "us-central1",
}); })
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
const mockMessages: Anthropic.Messages.MessageParam[] = [ const mockMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello' content: "Hello",
}, },
{ {
role: 'assistant', role: "assistant",
content: 'Hi there!' content: "Hi there!",
} },
]; ]
const systemPrompt = 'You are a helpful assistant'; const systemPrompt = "You are a helpful assistant"
it('should handle streaming responses correctly', async () => { it("should handle streaming responses correctly", async () => {
const mockStream = [ const mockStream = [
{ {
type: 'message_start', type: "message_start",
message: { message: {
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 0 output_tokens: 0,
} },
} },
}, },
{ {
type: 'content_block_start', type: "content_block_start",
index: 0, index: 0,
content_block: { content_block: {
type: 'text', type: "text",
text: 'Hello' text: "Hello",
} },
}, },
{ {
type: 'content_block_delta', type: "content_block_delta",
delta: { delta: {
type: 'text_delta', type: "text_delta",
text: ' world!' text: " world!",
} },
}, },
{ {
type: 'message_delta', type: "message_delta",
usage: { usage: {
output_tokens: 5 output_tokens: 5,
} },
} },
]; ]
// Setup async iterator for mock stream // Setup async iterator for mock stream
const asyncIterator = { const asyncIterator = {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) { for (const chunk of mockStream) {
yield chunk; yield chunk
} }
},
} }
};
const mockCreate = jest.fn().mockResolvedValue(asyncIterator); const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
(handler['client'].messages as any).create = mockCreate; ;(handler["client"].messages as any).create = mockCreate
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []; const chunks = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBe(4); expect(chunks.length).toBe(4)
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'usage', type: "usage",
inputTokens: 10, inputTokens: 10,
outputTokens: 0 outputTokens: 0,
}); })
expect(chunks[1]).toEqual({ expect(chunks[1]).toEqual({
type: 'text', type: "text",
text: 'Hello' text: "Hello",
}); })
expect(chunks[2]).toEqual({ expect(chunks[2]).toEqual({
type: 'text', type: "text",
text: ' world!' text: " world!",
}); })
expect(chunks[3]).toEqual({ expect(chunks[3]).toEqual({
type: 'usage', type: "usage",
inputTokens: 0, inputTokens: 0,
outputTokens: 5 outputTokens: 5,
}); })
expect(mockCreate).toHaveBeenCalledWith({ expect(mockCreate).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022', model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192, max_tokens: 8192,
temperature: 0, temperature: 0,
system: systemPrompt, system: systemPrompt,
messages: mockMessages, messages: mockMessages,
stream: true stream: true,
}); })
}); })
it('should handle multiple content blocks with line breaks', async () => { it("should handle multiple content blocks with line breaks", async () => {
const mockStream = [ const mockStream = [
{ {
type: 'content_block_start', type: "content_block_start",
index: 0, index: 0,
content_block: { content_block: {
type: 'text', type: "text",
text: 'First line' text: "First line",
} },
}, },
{ {
type: 'content_block_start', type: "content_block_start",
index: 1, index: 1,
content_block: { content_block: {
type: 'text', type: "text",
text: 'Second line' text: "Second line",
} },
} },
]; ]
const asyncIterator = { const asyncIterator = {
async *[Symbol.asyncIterator]() { async *[Symbol.asyncIterator]() {
for (const chunk of mockStream) { for (const chunk of mockStream) {
yield chunk; yield chunk
} }
},
} }
};
const mockCreate = jest.fn().mockResolvedValue(asyncIterator); const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
(handler['client'].messages as any).create = mockCreate; ;(handler["client"].messages as any).create = mockCreate
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []; const chunks = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks.length).toBe(3); expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'text', type: "text",
text: 'First line' text: "First line",
}); })
expect(chunks[1]).toEqual({ expect(chunks[1]).toEqual({
type: 'text', type: "text",
text: '\n' text: "\n",
}); })
expect(chunks[2]).toEqual({ expect(chunks[2]).toEqual({
type: 'text', type: "text",
text: 'Second line' text: "Second line",
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
const mockError = new Error('Vertex API error'); const mockError = new Error("Vertex API error")
const mockCreate = jest.fn().mockRejectedValue(mockError); const mockCreate = jest.fn().mockRejectedValue(mockError)
(handler['client'].messages as any).create = mockCreate; ;(handler["client"].messages as any).create = mockCreate
const stream = handler.createMessage(systemPrompt, mockMessages); const stream = handler.createMessage(systemPrompt, mockMessages)
await expect(async () => { await expect(async () => {
for await (const chunk of stream) { for await (const chunk of stream) {
// Should throw before yielding any chunks // Should throw before yielding any chunks
} }
}).rejects.toThrow('Vertex API error'); }).rejects.toThrow("Vertex API error")
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete prompt successfully', async () => { it("should complete prompt successfully", async () => {
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe('Test response'); expect(result).toBe("Test response")
expect(handler['client'].messages.create).toHaveBeenCalledWith({ expect(handler["client"].messages.create).toHaveBeenCalledWith({
model: 'claude-3-5-sonnet-v2@20241022', model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192, max_tokens: 8192,
temperature: 0, temperature: 0,
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: "user", content: "Test prompt" }],
stream: false stream: false,
}); })
}); })
it('should handle API errors', async () => { it("should handle API errors", async () => {
const mockError = new Error('Vertex API error'); const mockError = new Error("Vertex API error")
const mockCreate = jest.fn().mockRejectedValue(mockError); const mockCreate = jest.fn().mockRejectedValue(mockError)
(handler['client'].messages as any).create = mockCreate; ;(handler["client"].messages as any).create = mockCreate
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
.rejects.toThrow('Vertex completion error: Vertex API error'); "Vertex completion error: Vertex API error",
}); )
})
it('should handle non-text content', async () => { it("should handle non-text content", async () => {
const mockCreate = jest.fn().mockResolvedValue({ const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'image' }] content: [{ type: "image" }],
}); })
(handler['client'].messages as any).create = mockCreate; ;(handler["client"].messages as any).create = mockCreate
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
it('should handle empty response', async () => { it("should handle empty response", async () => {
const mockCreate = jest.fn().mockResolvedValue({ const mockCreate = jest.fn().mockResolvedValue({
content: [{ type: 'text', text: '' }] content: [{ type: "text", text: "" }],
}); })
(handler['client'].messages as any).create = mockCreate; ;(handler["client"].messages as any).create = mockCreate
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(''); expect(result).toBe("")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return correct model info', () => { it("should return correct model info", () => {
const modelInfo = handler.getModel(); const modelInfo = handler.getModel()
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
expect(modelInfo.info).toBeDefined(); expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192); expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(200_000); expect(modelInfo.info.contextWindow).toBe(200_000)
}); })
it('should return default model if invalid model specified', () => { it("should return default model if invalid model specified", () => {
const invalidHandler = new VertexHandler({ const invalidHandler = new VertexHandler({
apiModelId: 'invalid-model', apiModelId: "invalid-model",
vertexProjectId: 'test-project', vertexProjectId: "test-project",
vertexRegion: 'us-central1' vertexRegion: "us-central1",
}); })
const modelInfo = invalidHandler.getModel(); const modelInfo = invalidHandler.getModel()
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model
}); })
}); })
}); })

View File

@@ -1,289 +1,295 @@
import * as vscode from 'vscode'; import * as vscode from "vscode"
import { VsCodeLmHandler } from '../vscode-lm'; import { VsCodeLmHandler } from "../vscode-lm"
import { ApiHandlerOptions } from '../../../shared/api'; import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
// Mock vscode namespace // Mock vscode namespace
jest.mock('vscode', () => { jest.mock("vscode", () => {
class MockLanguageModelTextPart { class MockLanguageModelTextPart {
type = 'text'; type = "text"
constructor(public value: string) {} constructor(public value: string) {}
} }
class MockLanguageModelToolCallPart { class MockLanguageModelToolCallPart {
type = 'tool_call'; type = "tool_call"
constructor( constructor(
public callId: string, public callId: string,
public name: string, public name: string,
public input: any public input: any,
) {} ) {}
} }
return { return {
workspace: { workspace: {
onDidChangeConfiguration: jest.fn((callback) => ({ onDidChangeConfiguration: jest.fn((callback) => ({
dispose: jest.fn() dispose: jest.fn(),
})) })),
}, },
CancellationTokenSource: jest.fn(() => ({ CancellationTokenSource: jest.fn(() => ({
token: { token: {
isCancellationRequested: false, isCancellationRequested: false,
onCancellationRequested: jest.fn() onCancellationRequested: jest.fn(),
}, },
cancel: jest.fn(), cancel: jest.fn(),
dispose: jest.fn() dispose: jest.fn(),
})), })),
CancellationError: class CancellationError extends Error { CancellationError: class CancellationError extends Error {
constructor() { constructor() {
super('Operation cancelled'); super("Operation cancelled")
this.name = 'CancellationError'; this.name = "CancellationError"
} }
}, },
LanguageModelChatMessage: { LanguageModelChatMessage: {
Assistant: jest.fn((content) => ({ Assistant: jest.fn((content) => ({
role: 'assistant', role: "assistant",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})), })),
User: jest.fn((content) => ({ User: jest.fn((content) => ({
role: 'user', role: "user",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})) })),
}, },
LanguageModelTextPart: MockLanguageModelTextPart, LanguageModelTextPart: MockLanguageModelTextPart,
LanguageModelToolCallPart: MockLanguageModelToolCallPart, LanguageModelToolCallPart: MockLanguageModelToolCallPart,
lm: { lm: {
selectChatModels: jest.fn() selectChatModels: jest.fn(),
},
} }
}; })
});
const mockLanguageModelChat = { const mockLanguageModelChat = {
id: 'test-model', id: "test-model",
name: 'Test Model', name: "Test Model",
vendor: 'test-vendor', vendor: "test-vendor",
family: 'test-family', family: "test-family",
version: '1.0', version: "1.0",
maxInputTokens: 4096, maxInputTokens: 4096,
sendRequest: jest.fn(), sendRequest: jest.fn(),
countTokens: jest.fn() countTokens: jest.fn(),
}; }
describe('VsCodeLmHandler', () => { describe("VsCodeLmHandler", () => {
let handler: VsCodeLmHandler; let handler: VsCodeLmHandler
const defaultOptions: ApiHandlerOptions = { const defaultOptions: ApiHandlerOptions = {
vsCodeLmModelSelector: { vsCodeLmModelSelector: {
vendor: 'test-vendor', vendor: "test-vendor",
family: 'test-family' family: "test-family",
},
} }
};
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks()
handler = new VsCodeLmHandler(defaultOptions); handler = new VsCodeLmHandler(defaultOptions)
}); })
afterEach(() => { afterEach(() => {
handler.dispose(); handler.dispose()
}); })
describe('constructor', () => { describe("constructor", () => {
it('should initialize with provided options', () => { it("should initialize with provided options", () => {
expect(handler).toBeDefined(); expect(handler).toBeDefined()
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled(); expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled()
}); })
it('should handle configuration changes', () => { it("should handle configuration changes", () => {
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]; const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]
callback({ affectsConfiguration: () => true }); callback({ affectsConfiguration: () => true })
// Should reset client when config changes // Should reset client when config changes
expect(handler['client']).toBeNull(); expect(handler["client"]).toBeNull()
}); })
}); })
describe('createClient', () => { describe("createClient", () => {
it('should create client with selector', async () => { it("should create client with selector", async () => {
const mockModel = { ...mockLanguageModelChat }; const mockModel = { ...mockLanguageModelChat }
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
const client = await handler['createClient']({ const client = await handler["createClient"]({
vendor: 'test-vendor', vendor: "test-vendor",
family: 'test-family' family: "test-family",
}); })
expect(client).toBeDefined(); expect(client).toBeDefined()
expect(client.id).toBe('test-model'); expect(client.id).toBe("test-model")
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({ expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
vendor: 'test-vendor', vendor: "test-vendor",
family: 'test-family' family: "test-family",
}); })
}); })
it('should return default client when no models available', async () => { it("should return default client when no models available", async () => {
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]); ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([])
const client = await handler['createClient']({}); const client = await handler["createClient"]({})
expect(client).toBeDefined(); expect(client).toBeDefined()
expect(client.id).toBe('default-lm'); expect(client.id).toBe("default-lm")
expect(client.vendor).toBe('vscode'); expect(client.vendor).toBe("vscode")
}); })
}); })
describe('createMessage', () => { describe("createMessage", () => {
beforeEach(() => { beforeEach(() => {
const mockModel = { ...mockLanguageModelChat }; const mockModel = { ...mockLanguageModelChat }
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
mockLanguageModelChat.countTokens.mockResolvedValue(10); mockLanguageModelChat.countTokens.mockResolvedValue(10)
}); })
it('should stream text responses', async () => { it("should stream text responses", async () => {
const systemPrompt = 'You are a helpful assistant'; const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ const messages: Anthropic.Messages.MessageParam[] = [
role: 'user' as const, {
content: 'Hello' role: "user" as const,
}]; content: "Hello",
},
]
const responseText = 'Hello! How can I help you?'; const responseText = "Hello! How can I help you?"
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () { stream: (async function* () {
yield new vscode.LanguageModelTextPart(responseText); yield new vscode.LanguageModelTextPart(responseText)
return; return
})(), })(),
text: (async function* () { text: (async function* () {
yield responseText; yield responseText
return; return
})() })(),
}); })
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks = []; const chunks = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks).toHaveLength(2); // Text chunk + usage chunk expect(chunks).toHaveLength(2) // Text chunk + usage chunk
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'text', type: "text",
text: responseText text: responseText,
}); })
expect(chunks[1]).toMatchObject({ expect(chunks[1]).toMatchObject({
type: 'usage', type: "usage",
inputTokens: expect.any(Number), inputTokens: expect.any(Number),
outputTokens: expect.any(Number) outputTokens: expect.any(Number),
}); })
}); })
it('should handle tool calls', async () => { it("should handle tool calls", async () => {
const systemPrompt = 'You are a helpful assistant'; const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ const messages: Anthropic.Messages.MessageParam[] = [
role: 'user' as const, {
content: 'Calculate 2+2' role: "user" as const,
}]; content: "Calculate 2+2",
},
]
const toolCallData = { const toolCallData = {
name: 'calculator', name: "calculator",
arguments: { operation: 'add', numbers: [2, 2] }, arguments: { operation: "add", numbers: [2, 2] },
callId: 'call-1' callId: "call-1",
}; }
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () { stream: (async function* () {
yield new vscode.LanguageModelToolCallPart( yield new vscode.LanguageModelToolCallPart(
toolCallData.callId, toolCallData.callId,
toolCallData.name, toolCallData.name,
toolCallData.arguments toolCallData.arguments,
); )
return; return
})(), })(),
text: (async function* () { text: (async function* () {
yield JSON.stringify({ type: 'tool_call', ...toolCallData }); yield JSON.stringify({ type: "tool_call", ...toolCallData })
return; return
})() })(),
}); })
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
const chunks = []; const chunks = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk)
} }
expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: 'text', type: "text",
text: JSON.stringify({ type: 'tool_call', ...toolCallData }) text: JSON.stringify({ type: "tool_call", ...toolCallData }),
}); })
}); })
it('should handle errors', async () => { it("should handle errors", async () => {
const systemPrompt = 'You are a helpful assistant'; const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ const messages: Anthropic.Messages.MessageParam[] = [
role: 'user' as const, {
content: 'Hello' role: "user" as const,
}]; content: "Hello",
},
]
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error')); mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
await expect(async () => { await expect(async () => {
const stream = handler.createMessage(systemPrompt, messages); const stream = handler.createMessage(systemPrompt, messages)
for await (const _ of stream) { for await (const _ of stream) {
// consume stream // consume stream
} }
}).rejects.toThrow('API Error'); }).rejects.toThrow("API Error")
}); })
}); })
describe('getModel', () => { describe("getModel", () => {
it('should return model info when client exists', async () => { it("should return model info when client exists", async () => {
const mockModel = { ...mockLanguageModelChat }; const mockModel = { ...mockLanguageModelChat }
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
// Initialize client // Initialize client
await handler['getClient'](); await handler["getClient"]()
const model = handler.getModel(); const model = handler.getModel()
expect(model.id).toBe('test-model'); expect(model.id).toBe("test-model")
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
expect(model.info.contextWindow).toBe(4096); expect(model.info.contextWindow).toBe(4096)
}); })
it('should return fallback model info when no client exists', () => { it("should return fallback model info when no client exists", () => {
const model = handler.getModel(); const model = handler.getModel()
expect(model.id).toBe('test-vendor/test-family'); expect(model.id).toBe("test-vendor/test-family")
expect(model.info).toBeDefined(); expect(model.info).toBeDefined()
}); })
}); })
describe('completePrompt', () => { describe("completePrompt", () => {
it('should complete single prompt', async () => { it("should complete single prompt", async () => {
const mockModel = { ...mockLanguageModelChat }; const mockModel = { ...mockLanguageModelChat }
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
const responseText = 'Completed text'; const responseText = "Completed text"
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () { stream: (async function* () {
yield new vscode.LanguageModelTextPart(responseText); yield new vscode.LanguageModelTextPart(responseText)
return; return
})(), })(),
text: (async function* () { text: (async function* () {
yield responseText; yield responseText
return; return
})() })(),
}); })
const result = await handler.completePrompt('Test prompt'); const result = await handler.completePrompt("Test prompt")
expect(result).toBe(responseText); expect(result).toBe(responseText)
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled(); expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
}); })
it('should handle errors during completion', async () => { it("should handle errors during completion", async () => {
const mockModel = { ...mockLanguageModelChat }; const mockModel = { ...mockLanguageModelChat }
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]); ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed')); mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))
await expect(handler.completePrompt('Test prompt')) await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
.rejects "VSCode LM completion error: Completion failed",
.toThrow('VSCode LM completion error: Completion failed'); )
}); })
}); })
}); })

View File

@@ -181,14 +181,14 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
max_tokens: this.getModel().info.maxTokens || 8192, max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0, temperature: 0,
messages: [{ role: "user", content: prompt }], messages: [{ role: "user", content: prompt }],
stream: false stream: false,
}) })
const content = response.content[0] const content = response.content[0]
if (content.type === 'text') { if (content.type === "text") {
return content.text return content.text
} }
return '' return ""
} catch (error) { } catch (error) {
if (error instanceof Error) { if (error instanceof Error) {
throw new Error(`Anthropic completion error: ${error.message}`) throw new Error(`Anthropic completion error: ${error.message}`)

View File

@@ -1,4 +1,9 @@
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime" import {
BedrockRuntimeClient,
ConverseStreamCommand,
ConverseCommand,
BedrockRuntimeClientConfig,
} from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
@@ -8,34 +13,34 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../
// Define types for stream events based on AWS SDK // Define types for stream events based on AWS SDK
export interface StreamEvent { export interface StreamEvent {
messageStart?: { messageStart?: {
role?: string; role?: string
}; }
messageStop?: { messageStop?: {
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"; stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
additionalModelResponseFields?: Record<string, unknown>; additionalModelResponseFields?: Record<string, unknown>
}; }
contentBlockStart?: { contentBlockStart?: {
start?: { start?: {
text?: string; text?: string
}; }
contentBlockIndex?: number; contentBlockIndex?: number
}; }
contentBlockDelta?: { contentBlockDelta?: {
delta?: { delta?: {
text?: string; text?: string
}; }
contentBlockIndex?: number; contentBlockIndex?: number
}; }
metadata?: { metadata?: {
usage?: { usage?: {
inputTokens: number; inputTokens: number
outputTokens: number; outputTokens: number
totalTokens?: number; // Made optional since we don't use it totalTokens?: number // Made optional since we don't use it
}; }
metrics?: { metrics?: {
latencyMs: number; latencyMs: number
}; }
}; }
} }
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler { export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
@@ -47,7 +52,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
// Only include credentials if they actually exist // Only include credentials if they actually exist
const clientConfig: BedrockRuntimeClientConfig = { const clientConfig: BedrockRuntimeClientConfig = {
region: this.options.awsRegion || "us-east-1" region: this.options.awsRegion || "us-east-1",
} }
if (this.options.awsAccessKey && this.options.awsSecretKey) { if (this.options.awsAccessKey && this.options.awsSecretKey) {
@@ -55,7 +60,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
clientConfig.credentials = { clientConfig.credentials = {
accessKeyId: this.options.awsAccessKey, accessKeyId: this.options.awsAccessKey,
secretAccessKey: this.options.awsSecretKey, secretAccessKey: this.options.awsSecretKey,
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}) ...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
} }
} }
@@ -96,12 +101,14 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
maxTokens: modelConfig.info.maxTokens || 5000, maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3, temperature: 0.3,
topP: 0.1, topP: 0.1,
...(this.options.awsUsePromptCache ? { ...(this.options.awsUsePromptCache
? {
promptCache: { promptCache: {
promptCacheId: this.options.awspromptCacheId || "" promptCacheId: this.options.awspromptCacheId || "",
} },
} : {})
} }
: {}),
},
} }
try { try {
@@ -109,18 +116,16 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
const response = await this.client.send(command) const response = await this.client.send(command)
if (!response.stream) { if (!response.stream) {
throw new Error('No stream available in the response') throw new Error("No stream available in the response")
} }
for await (const chunk of response.stream) { for await (const chunk of response.stream) {
// Parse the chunk as JSON if it's a string (for tests) // Parse the chunk as JSON if it's a string (for tests)
let streamEvent: StreamEvent let streamEvent: StreamEvent
try { try {
streamEvent = typeof chunk === 'string' ? streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
JSON.parse(chunk) :
chunk as unknown as StreamEvent
} catch (e) { } catch (e) {
console.error('Failed to parse stream event:', e) console.error("Failed to parse stream event:", e)
continue continue
} }
@@ -129,7 +134,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
yield { yield {
type: "usage", type: "usage",
inputTokens: streamEvent.metadata.usage.inputTokens || 0, inputTokens: streamEvent.metadata.usage.inputTokens || 0,
outputTokens: streamEvent.metadata.usage.outputTokens || 0 outputTokens: streamEvent.metadata.usage.outputTokens || 0,
} }
continue continue
} }
@@ -143,7 +148,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
if (streamEvent.contentBlockStart?.start?.text) { if (streamEvent.contentBlockStart?.start?.text) {
yield { yield {
type: "text", type: "text",
text: streamEvent.contentBlockStart.start.text text: streamEvent.contentBlockStart.start.text,
} }
continue continue
} }
@@ -152,7 +157,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
if (streamEvent.contentBlockDelta?.delta?.text) { if (streamEvent.contentBlockDelta?.delta?.text) {
yield { yield {
type: "text", type: "text",
text: streamEvent.contentBlockDelta.delta.text text: streamEvent.contentBlockDelta.delta.text,
} }
continue continue
} }
@@ -162,32 +167,31 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
continue continue
} }
} }
} catch (error: unknown) { } catch (error: unknown) {
console.error('Bedrock Runtime API Error:', error) console.error("Bedrock Runtime API Error:", error)
// Only access stack if error is an Error object // Only access stack if error is an Error object
if (error instanceof Error) { if (error instanceof Error) {
console.error('Error stack:', error.stack) console.error("Error stack:", error.stack)
yield { yield {
type: "text", type: "text",
text: `Error: ${error.message}` text: `Error: ${error.message}`,
} }
yield { yield {
type: "usage", type: "usage",
inputTokens: 0, inputTokens: 0,
outputTokens: 0 outputTokens: 0,
} }
throw error throw error
} else { } else {
const unknownError = new Error("An unknown error occurred") const unknownError = new Error("An unknown error occurred")
yield { yield {
type: "text", type: "text",
text: unknownError.message text: unknownError.message,
} }
yield { yield {
type: "usage", type: "usage",
inputTokens: 0, inputTokens: 0,
outputTokens: 0 outputTokens: 0,
} }
throw unknownError throw unknownError
} }
@@ -198,14 +202,14 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
const modelId = this.options.apiModelId const modelId = this.options.apiModelId
if (modelId) { if (modelId) {
// For tests, allow any model ID // For tests, allow any model ID
if (process.env.NODE_ENV === 'test') { if (process.env.NODE_ENV === "test") {
return { return {
id: modelId, id: modelId,
info: { info: {
maxTokens: 5000, maxTokens: 5000,
contextWindow: 128_000, contextWindow: 128_000,
supportsPromptCache: false supportsPromptCache: false,
} },
} }
} }
// For production, validate against known models // For production, validate against known models
@@ -216,7 +220,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
} }
return { return {
id: bedrockDefaultModelId, id: bedrockDefaultModelId,
info: bedrockModels[bedrockDefaultModelId] info: bedrockModels[bedrockDefaultModelId],
} }
} }
@@ -245,15 +249,17 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
const payload = { const payload = {
modelId, modelId,
messages: convertToBedrockConverseMessages([{ messages: convertToBedrockConverseMessages([
{
role: "user", role: "user",
content: prompt content: prompt,
}]), },
]),
inferenceConfig: { inferenceConfig: {
maxTokens: modelConfig.info.maxTokens || 5000, maxTokens: modelConfig.info.maxTokens || 5000,
temperature: 0.3, temperature: 0.3,
topP: 0.1 topP: 0.1,
} },
} }
const command = new ConverseCommand(payload) const command = new ConverseCommand(payload)
@@ -267,10 +273,10 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
return output.content return output.content
} }
} catch (parseError) { } catch (parseError) {
console.error('Failed to parse Bedrock response:', parseError) console.error("Failed to parse Bedrock response:", parseError)
} }
} }
return '' return ""
} catch (error) { } catch (error) {
if (error instanceof Error) { if (error instanceof Error) {
throw new Error(`Bedrock completion error: ${error.message}`) throw new Error(`Bedrock completion error: ${error.message}`)

View File

@@ -12,7 +12,7 @@ export class DeepSeekHandler extends OpenAiHandler {
openAiApiKey: options.deepSeekApiKey, openAiApiKey: options.deepSeekApiKey,
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId, openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1", openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
includeMaxTokens: true includeMaxTokens: true,
}) })
} }
@@ -20,7 +20,7 @@ export class DeepSeekHandler extends OpenAiHandler {
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
return { return {
id: modelId, id: modelId,
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId] info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
} }
} }
} }

View File

@@ -72,17 +72,17 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
maxTokens = 8_192 maxTokens = 8_192
} }
const { data: completion, response } = await this.client.chat.completions.create({ const { data: completion, response } = await this.client.chat.completions
.create({
model: this.getModel().id, model: this.getModel().id,
max_tokens: maxTokens, max_tokens: maxTokens,
temperature: 0, temperature: 0,
messages: openAiMessages, messages: openAiMessages,
stream: true, stream: true,
}).withResponse(); })
.withResponse()
const completionRequestId = response.headers.get( const completionRequestId = response.headers.get("x-completion-request-id")
'x-completion-request-id',
);
for await (const chunk of completion) { for await (const chunk of completion) {
const delta = chunk.choices[0]?.delta const delta = chunk.choices[0]?.delta
@@ -96,13 +96,16 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
} }
try { try {
const response = await axios.get(`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, { const response = await axios.get(
`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`,
{
headers: { headers: {
Authorization: `Bearer ${this.options.glamaApiKey}`, Authorization: `Bearer ${this.options.glamaApiKey}`,
}, },
}) },
)
const completionRequest = response.data; const completionRequest = response.data
if (completionRequest.tokenUsage) { if (completionRequest.tokenUsage) {
yield { yield {

View File

@@ -60,7 +60,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
model: this.getModel().id, model: this.getModel().id,
messages: [{ role: "user", content: prompt }], messages: [{ role: "user", content: prompt }],
temperature: 0, temperature: 0,
stream: false stream: false,
}) })
return response.choices[0]?.message.content || "" return response.choices[0]?.message.content || ""
} catch (error) { } catch (error) {

View File

@@ -53,7 +53,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
model: this.getModel().id, model: this.getModel().id,
messages: [{ role: "user", content: prompt }], messages: [{ role: "user", content: prompt }],
temperature: 0, temperature: 0,
stream: false stream: false,
}) })
return response.choices[0]?.message.content || "" return response.choices[0]?.message.content || ""
} catch (error) { } catch (error) {

View File

@@ -32,7 +32,10 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
// o1 doesnt support streaming or non-1 temp but does support a developer prompt // o1 doesnt support streaming or non-1 temp but does support a developer prompt
const response = await this.client.chat.completions.create({ const response = await this.client.chat.completions.create({
model: modelId, model: modelId,
messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)], messages: [
{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
...convertToOpenAiMessages(messages),
],
}) })
yield { yield {
type: "text", type: "text",
@@ -98,14 +101,14 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
// o1 doesn't support non-1 temp // o1 doesn't support non-1 temp
requestOptions = { requestOptions = {
model: modelId, model: modelId,
messages: [{ role: "user", content: prompt }] messages: [{ role: "user", content: prompt }],
} }
break break
default: default:
requestOptions = { requestOptions = {
model: modelId, model: modelId,
messages: [{ role: "user", content: prompt }], messages: [{ role: "user", content: prompt }],
temperature: 0 temperature: 0,
} }
} }

View File

@@ -17,7 +17,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
constructor(options: ApiHandlerOptions) { constructor(options: ApiHandlerOptions) {
this.options = options this.options = options
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai // Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host; const urlHost = new URL(this.options.openAiBaseUrl ?? "").host
if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) { if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) {
this.client = new AzureOpenAI({ this.client = new AzureOpenAI({
baseURL: this.options.openAiBaseUrl, baseURL: this.options.openAiBaseUrl,
@@ -39,7 +39,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
if (this.options.openAiStreamingEnabled ?? true) { if (this.options.openAiStreamingEnabled ?? true) {
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = { const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
role: "system", role: "system",
content: systemPrompt content: systemPrompt,
} }
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId, model: modelId,
@@ -74,7 +74,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
// o1 for instance doesnt support streaming, non-1 temp, or system prompt // o1 for instance doesnt support streaming, non-1 temp, or system prompt
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = { const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
role: "user", role: "user",
content: systemPrompt content: systemPrompt,
} }
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId, model: modelId,

View File

@@ -9,12 +9,12 @@ import delay from "delay"
// Add custom interface for OpenRouter params // Add custom interface for OpenRouter params
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & { type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
transforms?: string[]; transforms?: string[]
} }
// Add custom interface for OpenRouter usage chunk // Add custom interface for OpenRouter usage chunk
interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk { interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
fullResponseText: string; fullResponseText: string
} }
import { SingleCompletionHandler } from ".." import { SingleCompletionHandler } from ".."
@@ -35,7 +35,10 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
}) })
} }
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator<ApiStreamChunk> { async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): AsyncGenerator<ApiStreamChunk> {
// Convert Anthropic messages to OpenAI format // Convert Anthropic messages to OpenAI format
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt }, { role: "system", content: systemPrompt },
@@ -108,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
break break
} }
// https://openrouter.ai/docs/transforms // https://openrouter.ai/docs/transforms
let fullResponseText = ""; let fullResponseText = ""
const stream = await this.client.chat.completions.create({ const stream = await this.client.chat.completions.create({
model: this.getModel().id, model: this.getModel().id,
max_tokens: maxTokens, max_tokens: maxTokens,
@@ -116,8 +119,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
messages: openAiMessages, messages: openAiMessages,
stream: true, stream: true,
// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true. // This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }) ...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
} as OpenRouterChatCompletionParams); } as OpenRouterChatCompletionParams)
let genId: string | undefined let genId: string | undefined
@@ -135,11 +138,11 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
const delta = chunk.choices[0]?.delta const delta = chunk.choices[0]?.delta
if (delta?.content) { if (delta?.content) {
fullResponseText += delta.content; fullResponseText += delta.content
yield { yield {
type: "text", type: "text",
text: delta.content, text: delta.content,
} as ApiStreamChunk; } as ApiStreamChunk
} }
// if (chunk.usage) { // if (chunk.usage) {
// yield { // yield {
@@ -170,13 +173,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
inputTokens: generation?.native_tokens_prompt || 0, inputTokens: generation?.native_tokens_prompt || 0,
outputTokens: generation?.native_tokens_completion || 0, outputTokens: generation?.native_tokens_completion || 0,
totalCost: generation?.total_cost || 0, totalCost: generation?.total_cost || 0,
fullResponseText fullResponseText,
} as OpenRouterApiStreamUsageChunk; } as OpenRouterApiStreamUsageChunk
} catch (error) { } catch (error) {
// ignore if fails // ignore if fails
console.error("Error fetching OpenRouter generation details:", error) console.error("Error fetching OpenRouter generation details:", error)
} }
} }
getModel(): { id: string; info: ModelInfo } { getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.openRouterModelId const modelId = this.options.openRouterModelId
@@ -193,7 +195,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
model: this.getModel().id, model: this.getModel().id,
messages: [{ role: "user", content: prompt }], messages: [{ role: "user", content: prompt }],
temperature: 0, temperature: 0,
stream: false stream: false,
}) })
if ("error" in response) { if ("error" in response) {

View File

@@ -91,14 +91,14 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
max_tokens: this.getModel().info.maxTokens || 8192, max_tokens: this.getModel().info.maxTokens || 8192,
temperature: 0, temperature: 0,
messages: [{ role: "user", content: prompt }], messages: [{ role: "user", content: prompt }],
stream: false stream: false,
}) })
const content = response.content[0] const content = response.content[0]
if (content.type === 'text') { if (content.type === "text") {
return content.text return content.text
} }
return '' return ""
} catch (error) { } catch (error) {
if (error instanceof Error) { if (error instanceof Error) {
throw new Error(`Vertex completion error: ${error.message}`) throw new Error(`Vertex completion error: ${error.message}`)

View File

@@ -1,11 +1,11 @@
import { Anthropic } from "@anthropic-ai/sdk"; import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from 'vscode'; import * as vscode from "vscode"
import { ApiHandler, SingleCompletionHandler } from "../"; import { ApiHandler, SingleCompletionHandler } from "../"
import { calculateApiCost } from "../../utils/cost"; import { calculateApiCost } from "../../utils/cost"
import { ApiStream } from "../transform/stream"; import { ApiStream } from "../transform/stream"
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"; import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"; import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"; import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
/** /**
* Handles interaction with VS Code's Language Model API for chat-based operations. * Handles interaction with VS Code's Language Model API for chat-based operations.
@@ -35,39 +35,36 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
* ``` * ```
*/ */
export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler { export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
private options: ApiHandlerOptions; private client: vscode.LanguageModelChat | null
private client: vscode.LanguageModelChat | null; private disposable: vscode.Disposable | null
private disposable: vscode.Disposable | null; private currentRequestCancellation: vscode.CancellationTokenSource | null
private currentRequestCancellation: vscode.CancellationTokenSource | null;
constructor(options: ApiHandlerOptions) { constructor(options: ApiHandlerOptions) {
this.options = options; this.options = options
this.client = null; this.client = null
this.disposable = null; this.disposable = null
this.currentRequestCancellation = null; this.currentRequestCancellation = null
try { try {
// Listen for model changes and reset client // Listen for model changes and reset client
this.disposable = vscode.workspace.onDidChangeConfiguration(event => { this.disposable = vscode.workspace.onDidChangeConfiguration((event) => {
if (event.affectsConfiguration('lm')) { if (event.affectsConfiguration("lm")) {
try { try {
this.client = null; this.client = null
this.ensureCleanState(); this.ensureCleanState()
} } catch (error) {
catch (error) { console.error("Error during configuration change cleanup:", error)
console.error('Error during configuration change cleanup:', error);
} }
} }
}); })
} } catch (error) {
catch (error) {
// Ensure cleanup if constructor fails // Ensure cleanup if constructor fails
this.dispose(); this.dispose()
throw new Error( throw new Error(
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : 'Unknown error'}` `Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : "Unknown error"}`,
); )
} }
} }
@@ -84,39 +81,39 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
*/ */
async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> { async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> {
try { try {
const models = await vscode.lm.selectChatModels(selector); const models = await vscode.lm.selectChatModels(selector)
// Use first available model or create a minimal model object // Use first available model or create a minimal model object
if (models && Array.isArray(models) && models.length > 0) { if (models && Array.isArray(models) && models.length > 0) {
return models[0]; return models[0]
} }
// Create a minimal model if no models are available // Create a minimal model if no models are available
return { return {
id: 'default-lm', id: "default-lm",
name: 'Default Language Model', name: "Default Language Model",
vendor: 'vscode', vendor: "vscode",
family: 'lm', family: "lm",
version: '1.0', version: "1.0",
maxInputTokens: 8192, maxInputTokens: 8192,
sendRequest: async (messages, options, token) => { sendRequest: async (messages, options, token) => {
// Provide a minimal implementation // Provide a minimal implementation
return { return {
stream: (async function* () { stream: (async function* () {
yield new vscode.LanguageModelTextPart( yield new vscode.LanguageModelTextPart(
"Language model functionality is limited. Please check VS Code configuration." "Language model functionality is limited. Please check VS Code configuration.",
); )
})(), })(),
text: (async function* () { text: (async function* () {
yield "Language model functionality is limited. Please check VS Code configuration."; yield "Language model functionality is limited. Please check VS Code configuration."
})() })(),
}; }
}, },
countTokens: async () => 0 countTokens: async () => 0,
}; }
} catch (error) { } catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error'; const errorMessage = error instanceof Error ? error.message : "Unknown error"
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`); throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`)
} }
} }
@@ -137,230 +134,222 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
* Tool calls handling is currently a work in progress. * Tool calls handling is currently a work in progress.
*/ */
dispose(): void { dispose(): void {
if (this.disposable) { if (this.disposable) {
this.disposable.dispose()
this.disposable.dispose();
} }
if (this.currentRequestCancellation) { if (this.currentRequestCancellation) {
this.currentRequestCancellation.cancel()
this.currentRequestCancellation.cancel(); this.currentRequestCancellation.dispose()
this.currentRequestCancellation.dispose();
} }
} }
private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> { private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
// Check for required dependencies // Check for required dependencies
if (!this.client) { if (!this.client) {
console.warn('Cline <Language Model API>: No client available for token counting'); console.warn("Cline <Language Model API>: No client available for token counting")
return 0; return 0
} }
if (!this.currentRequestCancellation) { if (!this.currentRequestCancellation) {
console.warn('Cline <Language Model API>: No cancellation token available for token counting'); console.warn("Cline <Language Model API>: No cancellation token available for token counting")
return 0; return 0
} }
// Validate input // Validate input
if (!text) { if (!text) {
console.debug('Cline <Language Model API>: Empty text provided for token counting'); console.debug("Cline <Language Model API>: Empty text provided for token counting")
return 0; return 0
} }
try { try {
// Handle different input types // Handle different input types
let tokenCount: number; let tokenCount: number
if (typeof text === 'string') { if (typeof text === "string") {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token); tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
} else if (text instanceof vscode.LanguageModelChatMessage) { } else if (text instanceof vscode.LanguageModelChatMessage) {
// For chat messages, ensure we have content // For chat messages, ensure we have content
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) { if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
console.debug('Cline <Language Model API>: Empty chat message content'); console.debug("Cline <Language Model API>: Empty chat message content")
return 0; return 0
} }
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token); tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
} else { } else {
console.warn('Cline <Language Model API>: Invalid input type for token counting'); console.warn("Cline <Language Model API>: Invalid input type for token counting")
return 0; return 0
} }
// Validate the result // Validate the result
if (typeof tokenCount !== 'number') { if (typeof tokenCount !== "number") {
console.warn('Cline <Language Model API>: Non-numeric token count received:', tokenCount); console.warn("Cline <Language Model API>: Non-numeric token count received:", tokenCount)
return 0; return 0
} }
if (tokenCount < 0) { if (tokenCount < 0) {
console.warn('Cline <Language Model API>: Negative token count received:', tokenCount); console.warn("Cline <Language Model API>: Negative token count received:", tokenCount)
return 0; return 0
} }
return tokenCount; return tokenCount
} } catch (error) {
catch (error) {
// Handle specific error types // Handle specific error types
if (error instanceof vscode.CancellationError) { if (error instanceof vscode.CancellationError) {
console.debug('Cline <Language Model API>: Token counting cancelled by user'); console.debug("Cline <Language Model API>: Token counting cancelled by user")
return 0; return 0
} }
const errorMessage = error instanceof Error ? error.message : 'Unknown error'; const errorMessage = error instanceof Error ? error.message : "Unknown error"
console.warn('Cline <Language Model API>: Token counting failed:', errorMessage); console.warn("Cline <Language Model API>: Token counting failed:", errorMessage)
// Log additional error details if available // Log additional error details if available
if (error instanceof Error && error.stack) { if (error instanceof Error && error.stack) {
console.debug('Token counting error stack:', error.stack); console.debug("Token counting error stack:", error.stack)
} }
return 0; // Fallback to prevent stream interruption return 0 // Fallback to prevent stream interruption
} }
} }
private async calculateTotalInputTokens(systemPrompt: string, vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> { private async calculateTotalInputTokens(
systemPrompt: string,
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
): Promise<number> {
const systemTokens: number = await this.countTokens(systemPrompt)
const systemTokens: number = await this.countTokens(systemPrompt); const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.countTokens(msg)))
const messageTokens: number[] = await Promise.all( return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
vsCodeLmMessages.map(msg => this.countTokens(msg))
);
return systemTokens + messageTokens.reduce(
(sum: number, tokens: number): number => sum + tokens, 0
);
} }
private ensureCleanState(): void { private ensureCleanState(): void {
if (this.currentRequestCancellation) { if (this.currentRequestCancellation) {
this.currentRequestCancellation.cancel()
this.currentRequestCancellation.cancel(); this.currentRequestCancellation.dispose()
this.currentRequestCancellation.dispose(); this.currentRequestCancellation = null
this.currentRequestCancellation = null;
} }
} }
private async getClient(): Promise<vscode.LanguageModelChat> { private async getClient(): Promise<vscode.LanguageModelChat> {
if (!this.client) { if (!this.client) {
console.debug('Cline <Language Model API>: Getting client with options:', { console.debug("Cline <Language Model API>: Getting client with options:", {
vsCodeLmModelSelector: this.options.vsCodeLmModelSelector, vsCodeLmModelSelector: this.options.vsCodeLmModelSelector,
hasOptions: !!this.options, hasOptions: !!this.options,
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [] selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [],
}); })
try { try {
// Use default empty selector if none provided to get all available models // Use default empty selector if none provided to get all available models
const selector = this.options?.vsCodeLmModelSelector || {}; const selector = this.options?.vsCodeLmModelSelector || {}
console.debug('Cline <Language Model API>: Creating client with selector:', selector); console.debug("Cline <Language Model API>: Creating client with selector:", selector)
this.client = await this.createClient(selector); this.client = await this.createClient(selector)
} catch (error) { } catch (error) {
const message = error instanceof Error ? error.message : 'Unknown error'; const message = error instanceof Error ? error.message : "Unknown error"
console.error('Cline <Language Model API>: Client creation failed:', message); console.error("Cline <Language Model API>: Client creation failed:", message)
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`); throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`)
} }
} }
return this.client; return this.client
} }
private cleanTerminalOutput(text: string): string { private cleanTerminalOutput(text: string): string {
if (!text) { if (!text) {
return ''; return ""
} }
return text return (
text
// Нормализуем переносы строк // Нормализуем переносы строк
.replace(/\r\n/g, '\n') .replace(/\r\n/g, "\n")
.replace(/\r/g, '\n') .replace(/\r/g, "\n")
// Удаляем ANSI escape sequences // Удаляем ANSI escape sequences
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, '') // Полный набор ANSI sequences .replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // Полный набор ANSI sequences
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, '') // CSI sequences .replace(/\x9B[0-?]*[ -/]*[@-~]/g, "") // CSI sequences
// Удаляем последовательности установки заголовка терминала и прочие OSC sequences // Удаляем последовательности установки заголовка терминала и прочие OSC sequences
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, '') .replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, "")
// Удаляем управляющие символы // Удаляем управляющие символы
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, '') .replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, "")
// Удаляем escape-последовательности VS Code // Удаляем escape-последовательности VS Code
.replace(/\x1B[PD].*?\x1B\\/g, '') // DCS sequences .replace(/\x1B[PD].*?\x1B\\/g, "") // DCS sequences
.replace(/\x1B_.*?\x1B\\/g, '') // APC sequences .replace(/\x1B_.*?\x1B\\/g, "") // APC sequences
.replace(/\x1B\^.*?\x1B\\/g, '') // PM sequences .replace(/\x1B\^.*?\x1B\\/g, "") // PM sequences
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, '') // Cursor movement and clear screen .replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, "") // Cursor movement and clear screen
// Удаляем пути Windows и служебную информацию // Удаляем пути Windows и служебную информацию
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/mg, '') .replace(/^(?:PS )?[A-Z]:\\[^\n]*$/gm, "")
.replace(/^;?Cwd=.*$/mg, '') .replace(/^;?Cwd=.*$/gm, "")
// Очищаем экранированные последовательности // Очищаем экранированные последовательности
.replace(/\\x[0-9a-fA-F]{2}/g, '') .replace(/\\x[0-9a-fA-F]{2}/g, "")
.replace(/\\u[0-9a-fA-F]{4}/g, '') .replace(/\\u[0-9a-fA-F]{4}/g, "")
// Финальная очистка // Финальная очистка
.replace(/\n{3,}/g, '\n\n') // Убираем множественные пустые строки .replace(/\n{3,}/g, "\n\n") // Убираем множественные пустые строки
.trim(); .trim()
)
} }
private cleanMessageContent(content: any): any { private cleanMessageContent(content: any): any {
if (!content) { if (!content) {
return content; return content
} }
if (typeof content === 'string') { if (typeof content === "string") {
return this.cleanTerminalOutput(content); return this.cleanTerminalOutput(content)
} }
if (Array.isArray(content)) { if (Array.isArray(content)) {
return content.map(item => this.cleanMessageContent(item)); return content.map((item) => this.cleanMessageContent(item))
} }
if (typeof content === 'object') { if (typeof content === "object") {
const cleaned: any = {}; const cleaned: any = {}
for (const [key, value] of Object.entries(content)) { for (const [key, value] of Object.entries(content)) {
cleaned[key] = this.cleanMessageContent(value); cleaned[key] = this.cleanMessageContent(value)
} }
return cleaned; return cleaned
} }
return content; return content
} }
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
// Ensure clean state before starting a new request // Ensure clean state before starting a new request
this.ensureCleanState(); this.ensureCleanState()
const client: vscode.LanguageModelChat = await this.getClient(); const client: vscode.LanguageModelChat = await this.getClient()
// Clean system prompt and messages // Clean system prompt and messages
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt); const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt)
const cleanedMessages = messages.map(msg => ({ const cleanedMessages = messages.map((msg) => ({
...msg, ...msg,
content: this.cleanMessageContent(msg.content) content: this.cleanMessageContent(msg.content),
})); }))
// Convert Anthropic messages to VS Code LM messages // Convert Anthropic messages to VS Code LM messages
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [ const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [
vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt), vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt),
...convertToVsCodeLmMessages(cleanedMessages), ...convertToVsCodeLmMessages(cleanedMessages),
]; ]
// Initialize cancellation token for the request // Initialize cancellation token for the request
this.currentRequestCancellation = new vscode.CancellationTokenSource(); this.currentRequestCancellation = new vscode.CancellationTokenSource()
// Calculate input tokens before starting the stream // Calculate input tokens before starting the stream
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages); const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
// Accumulate the text and count at the end of the stream to reduce token counting overhead. // Accumulate the text and count at the end of the stream to reduce token counting overhead.
let accumulatedText: string = ''; let accumulatedText: string = ""
try { try {
// Create the response stream with minimal required options // Create the response stream with minimal required options
const requestOptions: vscode.LanguageModelChatRequestOptions = { const requestOptions: vscode.LanguageModelChatRequestOptions = {
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.` justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`,
}; }
// Note: Tool support is currently provided by the VSCode Language Model API directly // Note: Tool support is currently provided by the VSCode Language Model API directly
// Extensions can register tools using vscode.lm.registerTool() // Extensions can register tools using vscode.lm.registerTool()
@@ -368,40 +357,40 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
const response: vscode.LanguageModelChatResponse = await client.sendRequest( const response: vscode.LanguageModelChatResponse = await client.sendRequest(
vsCodeLmMessages, vsCodeLmMessages,
requestOptions, requestOptions,
this.currentRequestCancellation.token this.currentRequestCancellation.token,
); )
// Consume the stream and handle both text and tool call chunks // Consume the stream and handle both text and tool call chunks
for await (const chunk of response.stream) { for await (const chunk of response.stream) {
if (chunk instanceof vscode.LanguageModelTextPart) { if (chunk instanceof vscode.LanguageModelTextPart) {
// Validate text part value // Validate text part value
if (typeof chunk.value !== 'string') { if (typeof chunk.value !== "string") {
console.warn('Cline <Language Model API>: Invalid text part value received:', chunk.value); console.warn("Cline <Language Model API>: Invalid text part value received:", chunk.value)
continue; continue
} }
accumulatedText += chunk.value; accumulatedText += chunk.value
yield { yield {
type: "text", type: "text",
text: chunk.value, text: chunk.value,
}; }
} else if (chunk instanceof vscode.LanguageModelToolCallPart) { } else if (chunk instanceof vscode.LanguageModelToolCallPart) {
try { try {
// Validate tool call parameters // Validate tool call parameters
if (!chunk.name || typeof chunk.name !== 'string') { if (!chunk.name || typeof chunk.name !== "string") {
console.warn('Cline <Language Model API>: Invalid tool name received:', chunk.name); console.warn("Cline <Language Model API>: Invalid tool name received:", chunk.name)
continue; continue
} }
if (!chunk.callId || typeof chunk.callId !== 'string') { if (!chunk.callId || typeof chunk.callId !== "string") {
console.warn('Cline <Language Model API>: Invalid tool callId received:', chunk.callId); console.warn("Cline <Language Model API>: Invalid tool callId received:", chunk.callId)
continue; continue
} }
// Ensure input is a valid object // Ensure input is a valid object
if (!chunk.input || typeof chunk.input !== 'object') { if (!chunk.input || typeof chunk.input !== "object") {
console.warn('Cline <Language Model API>: Invalid tool input received:', chunk.input); console.warn("Cline <Language Model API>: Invalid tool input received:", chunk.input)
continue; continue
} }
// Convert tool calls to text format with proper error handling // Convert tool calls to text format with proper error handling
@@ -409,82 +398,75 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
type: "tool_call", type: "tool_call",
name: chunk.name, name: chunk.name,
arguments: chunk.input, arguments: chunk.input,
callId: chunk.callId callId: chunk.callId,
}; }
const toolCallText = JSON.stringify(toolCall); const toolCallText = JSON.stringify(toolCall)
accumulatedText += toolCallText; accumulatedText += toolCallText
// Log tool call for debugging // Log tool call for debugging
console.debug('Cline <Language Model API>: Processing tool call:', { console.debug("Cline <Language Model API>: Processing tool call:", {
name: chunk.name, name: chunk.name,
callId: chunk.callId, callId: chunk.callId,
inputSize: JSON.stringify(chunk.input).length inputSize: JSON.stringify(chunk.input).length,
}); })
yield { yield {
type: "text", type: "text",
text: toolCallText, text: toolCallText,
}; }
} catch (error) { } catch (error) {
console.error('Cline <Language Model API>: Failed to process tool call:', error); console.error("Cline <Language Model API>: Failed to process tool call:", error)
// Continue processing other chunks even if one fails // Continue processing other chunks even if one fails
continue; continue
} }
} else { } else {
console.warn('Cline <Language Model API>: Unknown chunk type received:', chunk); console.warn("Cline <Language Model API>: Unknown chunk type received:", chunk)
} }
} }
// Count tokens in the accumulated text after stream completion // Count tokens in the accumulated text after stream completion
const totalOutputTokens: number = await this.countTokens(accumulatedText); const totalOutputTokens: number = await this.countTokens(accumulatedText)
// Report final usage after stream completion // Report final usage after stream completion
yield { yield {
type: "usage", type: "usage",
inputTokens: totalInputTokens, inputTokens: totalInputTokens,
outputTokens: totalOutputTokens, outputTokens: totalOutputTokens,
totalCost: calculateApiCost( totalCost: calculateApiCost(this.getModel().info, totalInputTokens, totalOutputTokens),
this.getModel().info,
totalInputTokens,
totalOutputTokens
)
};
} }
catch (error: unknown) { } catch (error: unknown) {
this.ensureCleanState()
this.ensureCleanState();
if (error instanceof vscode.CancellationError) { if (error instanceof vscode.CancellationError) {
throw new Error("Cline <Language Model API>: Request cancelled by user")
throw new Error("Cline <Language Model API>: Request cancelled by user");
} }
if (error instanceof Error) { if (error instanceof Error) {
console.error('Cline <Language Model API>: Stream error details:', { console.error("Cline <Language Model API>: Stream error details:", {
message: error.message, message: error.message,
stack: error.stack, stack: error.stack,
name: error.name name: error.name,
}); })
// Return original error if it's already an Error instance // Return original error if it's already an Error instance
throw error; throw error
} else if (typeof error === 'object' && error !== null) { } else if (typeof error === "object" && error !== null) {
// Handle error-like objects // Handle error-like objects
const errorDetails = JSON.stringify(error, null, 2); const errorDetails = JSON.stringify(error, null, 2)
console.error('Cline <Language Model API>: Stream error object:', errorDetails); console.error("Cline <Language Model API>: Stream error object:", errorDetails)
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`); throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`)
} else { } else {
// Fallback for unknown error types // Fallback for unknown error types
const errorMessage = String(error); const errorMessage = String(error)
console.error('Cline <Language Model API>: Unknown stream error:', errorMessage); console.error("Cline <Language Model API>: Unknown stream error:", errorMessage)
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`); throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`)
} }
} }
} }
// Return model information based on the current client state // Return model information based on the current client state
getModel(): { id: string; info: ModelInfo; } { getModel(): { id: string; info: ModelInfo } {
if (this.client) { if (this.client) {
// Validate client properties // Validate client properties
const requiredProps = { const requiredProps = {
@@ -492,68 +474,69 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
vendor: this.client.vendor, vendor: this.client.vendor,
family: this.client.family, family: this.client.family,
version: this.client.version, version: this.client.version,
maxInputTokens: this.client.maxInputTokens maxInputTokens: this.client.maxInputTokens,
}; }
// Log any missing properties for debugging // Log any missing properties for debugging
for (const [prop, value] of Object.entries(requiredProps)) { for (const [prop, value] of Object.entries(requiredProps)) {
if (!value && value !== 0) { if (!value && value !== 0) {
console.warn(`Cline <Language Model API>: Client missing ${prop} property`); console.warn(`Cline <Language Model API>: Client missing ${prop} property`)
} }
} }
// Construct model ID using available information // Construct model ID using available information
const modelParts = [ const modelParts = [this.client.vendor, this.client.family, this.client.version].filter(Boolean)
this.client.vendor,
this.client.family,
this.client.version
].filter(Boolean);
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR); const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR)
// Build model info with conservative defaults for missing values // Build model info with conservative defaults for missing values
const modelInfo: ModelInfo = { const modelInfo: ModelInfo = {
maxTokens: -1, // Unlimited tokens by default maxTokens: -1, // Unlimited tokens by default
contextWindow: typeof this.client.maxInputTokens === 'number' contextWindow:
typeof this.client.maxInputTokens === "number"
? Math.max(0, this.client.maxInputTokens) ? Math.max(0, this.client.maxInputTokens)
: openAiModelInfoSaneDefaults.contextWindow, : openAiModelInfoSaneDefaults.contextWindow,
supportsImages: false, // VSCode Language Model API currently doesn't support image inputs supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
supportsPromptCache: true, supportsPromptCache: true,
inputPrice: 0, inputPrice: 0,
outputPrice: 0, outputPrice: 0,
description: `VSCode Language Model: ${modelId}` description: `VSCode Language Model: ${modelId}`,
}; }
return { id: modelId, info: modelInfo }; return { id: modelId, info: modelInfo }
} }
// Fallback when no client is available // Fallback when no client is available
const fallbackId = this.options.vsCodeLmModelSelector const fallbackId = this.options.vsCodeLmModelSelector
? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector) ? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector)
: "vscode-lm"; : "vscode-lm"
console.debug('Cline <Language Model API>: No client available, using fallback model info'); console.debug("Cline <Language Model API>: No client available, using fallback model info")
return { return {
id: fallbackId, id: fallbackId,
info: { info: {
...openAiModelInfoSaneDefaults, ...openAiModelInfoSaneDefaults,
description: `VSCode Language Model (Fallback): ${fallbackId}` description: `VSCode Language Model (Fallback): ${fallbackId}`,
},
} }
};
} }
async completePrompt(prompt: string): Promise<string> { async completePrompt(prompt: string): Promise<string> {
try { try {
const client = await this.getClient(); const client = await this.getClient()
const response = await client.sendRequest([vscode.LanguageModelChatMessage.User(prompt)], {}, new vscode.CancellationTokenSource().token); const response = await client.sendRequest(
let result = ""; [vscode.LanguageModelChatMessage.User(prompt)],
{},
new vscode.CancellationTokenSource().token,
)
let result = ""
for await (const chunk of response.stream) { for await (const chunk of response.stream) {
if (chunk instanceof vscode.LanguageModelTextPart) { if (chunk instanceof vscode.LanguageModelTextPart) {
result += chunk.value; result += chunk.value
} }
} }
return result; return result
} catch (error) { } catch (error) {
if (error instanceof Error) { if (error instanceof Error) {
throw new Error(`VSCode LM completion error: ${error.message}`) throw new Error(`VSCode LM completion error: ${error.message}`)

View File

@@ -1,251 +1,249 @@
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format' import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../bedrock-converse-format"
import { Anthropic } from '@anthropic-ai/sdk' import { Anthropic } from "@anthropic-ai/sdk"
import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime' import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime"
import { StreamEvent } from '../../providers/bedrock' import { StreamEvent } from "../../providers/bedrock"
describe('bedrock-converse-format', () => { describe("bedrock-converse-format", () => {
describe('convertToBedrockConverseMessages', () => { describe("convertToBedrockConverseMessages", () => {
test('converts simple text messages correctly', () => { test("converts simple text messages correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'Hello' }, { role: "user", content: "Hello" },
{ role: 'assistant', content: 'Hi there' } { role: "assistant", content: "Hi there" },
] ]
const result = convertToBedrockConverseMessages(messages) const result = convertToBedrockConverseMessages(messages)
expect(result).toEqual([ expect(result).toEqual([
{ {
role: 'user', role: "user",
content: [{ text: 'Hello' }] content: [{ text: "Hello" }],
}, },
{ {
role: 'assistant', role: "assistant",
content: [{ text: 'Hi there' }] content: [{ text: "Hi there" }],
} },
]) ])
}) })
test('converts messages with images correctly', () => { test("converts messages with images correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [ content: [
{ {
type: 'text', type: "text",
text: 'Look at this image:' text: "Look at this image:",
}, },
{ {
type: 'image', type: "image",
source: { source: {
type: 'base64', type: "base64",
data: 'SGVsbG8=', // "Hello" in base64 data: "SGVsbG8=", // "Hello" in base64
media_type: 'image/jpeg' as const media_type: "image/jpeg" as const,
} },
} },
] ],
} },
] ]
const result = convertToBedrockConverseMessages(messages) const result = convertToBedrockConverseMessages(messages)
if (!result[0] || !result[0].content) { if (!result[0] || !result[0].content) {
fail('Expected result to have content') fail("Expected result to have content")
return return
} }
expect(result[0].role).toBe('user') expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(2) expect(result[0].content).toHaveLength(2)
expect(result[0].content[0]).toEqual({ text: 'Look at this image:' }) expect(result[0].content[0]).toEqual({ text: "Look at this image:" })
const imageBlock = result[0].content[1] as ContentBlock const imageBlock = result[0].content[1] as ContentBlock
if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) { if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) {
expect(imageBlock.image.format).toBe('jpeg') expect(imageBlock.image.format).toBe("jpeg")
expect(imageBlock.image.source).toBeDefined() expect(imageBlock.image.source).toBeDefined()
expect(imageBlock.image.source.bytes).toBeDefined() expect(imageBlock.image.source.bytes).toBeDefined()
} else { } else {
fail('Expected image block not found') fail("Expected image block not found")
} }
}) })
test('converts tool use messages correctly', () => { test("converts tool use messages correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'assistant', role: "assistant",
content: [ content: [
{ {
type: 'tool_use', type: "tool_use",
id: 'test-id', id: "test-id",
name: 'read_file', name: "read_file",
input: { input: {
path: 'test.txt' path: "test.txt",
} },
} },
] ],
} },
] ]
const result = convertToBedrockConverseMessages(messages) const result = convertToBedrockConverseMessages(messages)
if (!result[0] || !result[0].content) { if (!result[0] || !result[0].content) {
fail('Expected result to have content') fail("Expected result to have content")
return return
} }
expect(result[0].role).toBe('assistant') expect(result[0].role).toBe("assistant")
const toolBlock = result[0].content[0] as ContentBlock const toolBlock = result[0].content[0] as ContentBlock
if ('toolUse' in toolBlock && toolBlock.toolUse) { if ("toolUse" in toolBlock && toolBlock.toolUse) {
expect(toolBlock.toolUse).toEqual({ expect(toolBlock.toolUse).toEqual({
toolUseId: 'test-id', toolUseId: "test-id",
name: 'read_file', name: "read_file",
input: '<read_file>\n<path>\ntest.txt\n</path>\n</read_file>' input: "<read_file>\n<path>\ntest.txt\n</path>\n</read_file>",
}) })
} else { } else {
fail('Expected tool use block not found') fail("Expected tool use block not found")
} }
}) })
test('converts tool result messages correctly', () => { test("converts tool result messages correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'assistant', role: "assistant",
content: [ content: [
{ {
type: 'tool_result', type: "tool_result",
tool_use_id: 'test-id', tool_use_id: "test-id",
content: [{ type: 'text', text: 'File contents here' }] content: [{ type: "text", text: "File contents here" }],
} },
] ],
} },
] ]
const result = convertToBedrockConverseMessages(messages) const result = convertToBedrockConverseMessages(messages)
if (!result[0] || !result[0].content) { if (!result[0] || !result[0].content) {
fail('Expected result to have content') fail("Expected result to have content")
return return
} }
expect(result[0].role).toBe('assistant') expect(result[0].role).toBe("assistant")
const resultBlock = result[0].content[0] as ContentBlock const resultBlock = result[0].content[0] as ContentBlock
if ('toolResult' in resultBlock && resultBlock.toolResult) { if ("toolResult" in resultBlock && resultBlock.toolResult) {
const expectedContent: ToolResultContentBlock[] = [ const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }]
{ text: 'File contents here' }
]
expect(resultBlock.toolResult).toEqual({ expect(resultBlock.toolResult).toEqual({
toolUseId: 'test-id', toolUseId: "test-id",
content: expectedContent, content: expectedContent,
status: 'success' status: "success",
}) })
} else { } else {
fail('Expected tool result block not found') fail("Expected tool result block not found")
} }
}) })
test('handles text content correctly', () => { test("handles text content correctly", () => {
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [ content: [
{ {
type: 'text', type: "text",
text: 'Hello world' text: "Hello world",
} },
] ],
} },
] ]
const result = convertToBedrockConverseMessages(messages) const result = convertToBedrockConverseMessages(messages)
if (!result[0] || !result[0].content) { if (!result[0] || !result[0].content) {
fail('Expected result to have content') fail("Expected result to have content")
return return
} }
expect(result[0].role).toBe('user') expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(1) expect(result[0].content).toHaveLength(1)
const textBlock = result[0].content[0] as ContentBlock const textBlock = result[0].content[0] as ContentBlock
expect(textBlock).toEqual({ text: 'Hello world' }) expect(textBlock).toEqual({ text: "Hello world" })
}) })
}) })
describe('convertToAnthropicMessage', () => { describe("convertToAnthropicMessage", () => {
test('converts metadata events correctly', () => { test("converts metadata events correctly", () => {
const event: StreamEvent = { const event: StreamEvent = {
metadata: { metadata: {
usage: { usage: {
inputTokens: 10, inputTokens: 10,
outputTokens: 20 outputTokens: 20,
} },
} },
} }
const result = convertToAnthropicMessage(event, 'test-model') const result = convertToAnthropicMessage(event, "test-model")
expect(result).toEqual({ expect(result).toEqual({
id: '', id: "",
type: 'message', type: "message",
role: 'assistant', role: "assistant",
model: 'test-model', model: "test-model",
usage: { usage: {
input_tokens: 10, input_tokens: 10,
output_tokens: 20 output_tokens: 20,
} },
}) })
}) })
test('converts content block start events correctly', () => { test("converts content block start events correctly", () => {
const event: StreamEvent = { const event: StreamEvent = {
contentBlockStart: { contentBlockStart: {
start: { start: {
text: 'Hello' text: "Hello",
} },
} },
} }
const result = convertToAnthropicMessage(event, 'test-model') const result = convertToAnthropicMessage(event, "test-model")
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: "message",
role: 'assistant', role: "assistant",
content: [{ type: 'text', text: 'Hello' }], content: [{ type: "text", text: "Hello" }],
model: 'test-model' model: "test-model",
}) })
}) })
test('converts content block delta events correctly', () => { test("converts content block delta events correctly", () => {
const event: StreamEvent = { const event: StreamEvent = {
contentBlockDelta: { contentBlockDelta: {
delta: { delta: {
text: ' world' text: " world",
} },
} },
} }
const result = convertToAnthropicMessage(event, 'test-model') const result = convertToAnthropicMessage(event, "test-model")
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: "message",
role: 'assistant', role: "assistant",
content: [{ type: 'text', text: ' world' }], content: [{ type: "text", text: " world" }],
model: 'test-model' model: "test-model",
}) })
}) })
test('converts message stop events correctly', () => { test("converts message stop events correctly", () => {
const event: StreamEvent = { const event: StreamEvent = {
messageStop: { messageStop: {
stopReason: 'end_turn' as const stopReason: "end_turn" as const,
} },
} }
const result = convertToAnthropicMessage(event, 'test-model') const result = convertToAnthropicMessage(event, "test-model")
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: "message",
role: 'assistant', role: "assistant",
stop_reason: 'end_turn', stop_reason: "end_turn",
stop_sequence: null, stop_sequence: null,
model: 'test-model' model: "test-model",
}) })
}) })
}) })

View File

@@ -1,257 +1,275 @@
import { convertToOpenAiMessages, convertToAnthropicMessage } from '../openai-format'; import { convertToOpenAiMessages, convertToAnthropicMessage } from "../openai-format"
import { Anthropic } from '@anthropic-ai/sdk'; import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from 'openai'; import OpenAI from "openai"
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, 'choices'> & { type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, "choices"> & {
choices: Array<Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & { choices: Array<
message: OpenAI.Chat.Completions.ChatCompletion.Choice['message']; Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
finish_reason: string; message: OpenAI.Chat.Completions.ChatCompletion.Choice["message"]
index: number; finish_reason: string
}>; index: number
}; }
>
}
describe('OpenAI Format Transformations', () => { describe("OpenAI Format Transformations", () => {
describe('convertToOpenAiMessages', () => { describe("convertToOpenAiMessages", () => {
it('should convert simple text messages', () => { it("should convert simple text messages", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [ const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: 'Hello' content: "Hello",
}, },
{ {
role: 'assistant', role: "assistant",
content: 'Hi there!' content: "Hi there!",
} },
]; ]
const openAiMessages = convertToOpenAiMessages(anthropicMessages); const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(2); expect(openAiMessages).toHaveLength(2)
expect(openAiMessages[0]).toEqual({ expect(openAiMessages[0]).toEqual({
role: 'user', role: "user",
content: 'Hello' content: "Hello",
}); })
expect(openAiMessages[1]).toEqual({ expect(openAiMessages[1]).toEqual({
role: 'assistant', role: "assistant",
content: 'Hi there!' content: "Hi there!",
}); })
}); })
it('should handle messages with image content', () => { it("should handle messages with image content", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [ const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [ content: [
{ {
type: 'text', type: "text",
text: 'What is in this image?' text: "What is in this image?",
}, },
{ {
type: 'image', type: "image",
source: { source: {
type: 'base64', type: "base64",
media_type: 'image/jpeg', media_type: "image/jpeg",
data: 'base64data' data: "base64data",
} },
} },
],
},
] ]
}
];
const openAiMessages = convertToOpenAiMessages(anthropicMessages); const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(1); expect(openAiMessages).toHaveLength(1)
expect(openAiMessages[0].role).toBe('user'); expect(openAiMessages[0].role).toBe("user")
const content = openAiMessages[0].content as Array<{ const content = openAiMessages[0].content as Array<{
type: string; type: string
text?: string; text?: string
image_url?: { url: string }; image_url?: { url: string }
}>; }>
expect(Array.isArray(content)).toBe(true); expect(Array.isArray(content)).toBe(true)
expect(content).toHaveLength(2); expect(content).toHaveLength(2)
expect(content[0]).toEqual({ type: 'text', text: 'What is in this image?' }); expect(content[0]).toEqual({ type: "text", text: "What is in this image?" })
expect(content[1]).toEqual({ expect(content[1]).toEqual({
type: 'image_url', type: "image_url",
image_url: { url: '' } image_url: { url: "" },
}); })
}); })
it('should handle assistant messages with tool use', () => { it("should handle assistant messages with tool use", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [ const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'assistant', role: "assistant",
content: [ content: [
{ {
type: 'text', type: "text",
text: 'Let me check the weather.' text: "Let me check the weather.",
}, },
{ {
type: 'tool_use', type: "tool_use",
id: 'weather-123', id: "weather-123",
name: 'get_weather', name: "get_weather",
input: { city: 'London' } input: { city: "London" },
} },
],
},
] ]
}
];
const openAiMessages = convertToOpenAiMessages(anthropicMessages); const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(1); expect(openAiMessages).toHaveLength(1)
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam; const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam
expect(assistantMessage.role).toBe('assistant'); expect(assistantMessage.role).toBe("assistant")
expect(assistantMessage.content).toBe('Let me check the weather.'); expect(assistantMessage.content).toBe("Let me check the weather.")
expect(assistantMessage.tool_calls).toHaveLength(1); expect(assistantMessage.tool_calls).toHaveLength(1)
expect(assistantMessage.tool_calls![0]).toEqual({ expect(assistantMessage.tool_calls![0]).toEqual({
id: 'weather-123', id: "weather-123",
type: 'function', type: "function",
function: { function: {
name: 'get_weather', name: "get_weather",
arguments: JSON.stringify({ city: 'London' }) arguments: JSON.stringify({ city: "London" }),
} },
}); })
}); })
it('should handle user messages with tool results', () => { it("should handle user messages with tool results", () => {
const anthropicMessages: Anthropic.Messages.MessageParam[] = [ const anthropicMessages: Anthropic.Messages.MessageParam[] = [
{ {
role: 'user', role: "user",
content: [ content: [
{ {
type: 'tool_result', type: "tool_result",
tool_use_id: 'weather-123', tool_use_id: "weather-123",
content: 'Current temperature in London: 20°C' content: "Current temperature in London: 20°C",
}
]
}
];
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
expect(openAiMessages).toHaveLength(1);
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam;
expect(toolMessage.role).toBe('tool');
expect(toolMessage.tool_call_id).toBe('weather-123');
expect(toolMessage.content).toBe('Current temperature in London: 20°C');
});
});
describe('convertToAnthropicMessage', () => {
it('should convert simple completion', () => {
const openAiCompletion: PartialChatCompletion = {
id: 'completion-123',
model: 'gpt-4',
choices: [{
message: {
role: 'assistant',
content: 'Hello there!',
refusal: null
}, },
finish_reason: 'stop', ],
index: 0 },
}], ]
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
expect(openAiMessages).toHaveLength(1)
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam
expect(toolMessage.role).toBe("tool")
expect(toolMessage.tool_call_id).toBe("weather-123")
expect(toolMessage.content).toBe("Current temperature in London: 20°C")
})
})
describe("convertToAnthropicMessage", () => {
it("should convert simple completion", () => {
const openAiCompletion: PartialChatCompletion = {
id: "completion-123",
model: "gpt-4",
choices: [
{
message: {
role: "assistant",
content: "Hello there!",
refusal: null,
},
finish_reason: "stop",
index: 0,
},
],
usage: { usage: {
prompt_tokens: 10, prompt_tokens: 10,
completion_tokens: 5, completion_tokens: 5,
total_tokens: 15 total_tokens: 15,
}, },
created: 123456789, created: 123456789,
object: 'chat.completion' object: "chat.completion",
}; }
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion); const anthropicMessage = convertToAnthropicMessage(
expect(anthropicMessage.id).toBe('completion-123'); openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
expect(anthropicMessage.role).toBe('assistant'); )
expect(anthropicMessage.content).toHaveLength(1); expect(anthropicMessage.id).toBe("completion-123")
expect(anthropicMessage.role).toBe("assistant")
expect(anthropicMessage.content).toHaveLength(1)
expect(anthropicMessage.content[0]).toEqual({ expect(anthropicMessage.content[0]).toEqual({
type: 'text', type: "text",
text: 'Hello there!' text: "Hello there!",
}); })
expect(anthropicMessage.stop_reason).toBe('end_turn'); expect(anthropicMessage.stop_reason).toBe("end_turn")
expect(anthropicMessage.usage).toEqual({ expect(anthropicMessage.usage).toEqual({
input_tokens: 10, input_tokens: 10,
output_tokens: 5 output_tokens: 5,
}); })
}); })
it('should handle tool calls in completion', () => { it("should handle tool calls in completion", () => {
const openAiCompletion: PartialChatCompletion = { const openAiCompletion: PartialChatCompletion = {
id: 'completion-123', id: "completion-123",
model: 'gpt-4', model: "gpt-4",
choices: [{ choices: [
{
message: { message: {
role: 'assistant', role: "assistant",
content: 'Let me check the weather.', content: "Let me check the weather.",
tool_calls: [{ tool_calls: [
id: 'weather-123', {
type: 'function', id: "weather-123",
type: "function",
function: { function: {
name: 'get_weather', name: "get_weather",
arguments: '{"city":"London"}' arguments: '{"city":"London"}',
}
}],
refusal: null
}, },
finish_reason: 'tool_calls', },
index: 0 ],
}], refusal: null,
},
finish_reason: "tool_calls",
index: 0,
},
],
usage: { usage: {
prompt_tokens: 15, prompt_tokens: 15,
completion_tokens: 8, completion_tokens: 8,
total_tokens: 23 total_tokens: 23,
}, },
created: 123456789, created: 123456789,
object: 'chat.completion' object: "chat.completion",
};
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
expect(anthropicMessage.content).toHaveLength(2);
expect(anthropicMessage.content[0]).toEqual({
type: 'text',
text: 'Let me check the weather.'
});
expect(anthropicMessage.content[1]).toEqual({
type: 'tool_use',
id: 'weather-123',
name: 'get_weather',
input: { city: 'London' }
});
expect(anthropicMessage.stop_reason).toBe('tool_use');
});
it('should handle invalid tool call arguments', () => {
const openAiCompletion: PartialChatCompletion = {
id: 'completion-123',
model: 'gpt-4',
choices: [{
message: {
role: 'assistant',
content: 'Testing invalid arguments',
tool_calls: [{
id: 'test-123',
type: 'function',
function: {
name: 'test_function',
arguments: 'invalid json'
} }
}],
refusal: null
},
finish_reason: 'tool_calls',
index: 0
}],
created: 123456789,
object: 'chat.completion'
};
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion); const anthropicMessage = convertToAnthropicMessage(
expect(anthropicMessage.content).toHaveLength(2); openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
)
expect(anthropicMessage.content).toHaveLength(2)
expect(anthropicMessage.content[0]).toEqual({
type: "text",
text: "Let me check the weather.",
})
expect(anthropicMessage.content[1]).toEqual({ expect(anthropicMessage.content[1]).toEqual({
type: 'tool_use', type: "tool_use",
id: 'test-123', id: "weather-123",
name: 'test_function', name: "get_weather",
input: {} // Should default to empty object for invalid JSON input: { city: "London" },
}); })
}); expect(anthropicMessage.stop_reason).toBe("tool_use")
}); })
});
it("should handle invalid tool call arguments", () => {
const openAiCompletion: PartialChatCompletion = {
id: "completion-123",
model: "gpt-4",
choices: [
{
message: {
role: "assistant",
content: "Testing invalid arguments",
tool_calls: [
{
id: "test-123",
type: "function",
function: {
name: "test_function",
arguments: "invalid json",
},
},
],
refusal: null,
},
finish_reason: "tool_calls",
index: 0,
},
],
created: 123456789,
object: "chat.completion",
}
const anthropicMessage = convertToAnthropicMessage(
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
)
expect(anthropicMessage.content).toHaveLength(2)
expect(anthropicMessage.content[1]).toEqual({
type: "tool_use",
id: "test-123",
name: "test_function",
input: {}, // Should default to empty object for invalid JSON
})
})
})
})

View File

@@ -1,114 +1,114 @@
import { ApiStreamChunk } from '../stream'; import { ApiStreamChunk } from "../stream"
describe('API Stream Types', () => { describe("API Stream Types", () => {
describe('ApiStreamChunk', () => { describe("ApiStreamChunk", () => {
it('should correctly handle text chunks', () => { it("should correctly handle text chunks", () => {
const textChunk: ApiStreamChunk = { const textChunk: ApiStreamChunk = {
type: 'text', type: "text",
text: 'Hello world' text: "Hello world",
}; }
expect(textChunk.type).toBe('text'); expect(textChunk.type).toBe("text")
expect(textChunk.text).toBe('Hello world'); expect(textChunk.text).toBe("Hello world")
}); })
it('should correctly handle usage chunks with cache information', () => { it("should correctly handle usage chunks with cache information", () => {
const usageChunk: ApiStreamChunk = { const usageChunk: ApiStreamChunk = {
type: 'usage', type: "usage",
inputTokens: 100, inputTokens: 100,
outputTokens: 50, outputTokens: 50,
cacheWriteTokens: 20, cacheWriteTokens: 20,
cacheReadTokens: 10 cacheReadTokens: 10,
}; }
expect(usageChunk.type).toBe('usage'); expect(usageChunk.type).toBe("usage")
expect(usageChunk.inputTokens).toBe(100); expect(usageChunk.inputTokens).toBe(100)
expect(usageChunk.outputTokens).toBe(50); expect(usageChunk.outputTokens).toBe(50)
expect(usageChunk.cacheWriteTokens).toBe(20); expect(usageChunk.cacheWriteTokens).toBe(20)
expect(usageChunk.cacheReadTokens).toBe(10); expect(usageChunk.cacheReadTokens).toBe(10)
}); })
it('should handle usage chunks without cache tokens', () => { it("should handle usage chunks without cache tokens", () => {
const usageChunk: ApiStreamChunk = { const usageChunk: ApiStreamChunk = {
type: 'usage', type: "usage",
inputTokens: 100, inputTokens: 100,
outputTokens: 50 outputTokens: 50,
}; }
expect(usageChunk.type).toBe('usage'); expect(usageChunk.type).toBe("usage")
expect(usageChunk.inputTokens).toBe(100); expect(usageChunk.inputTokens).toBe(100)
expect(usageChunk.outputTokens).toBe(50); expect(usageChunk.outputTokens).toBe(50)
expect(usageChunk.cacheWriteTokens).toBeUndefined(); expect(usageChunk.cacheWriteTokens).toBeUndefined()
expect(usageChunk.cacheReadTokens).toBeUndefined(); expect(usageChunk.cacheReadTokens).toBeUndefined()
}); })
it('should handle text chunks with empty strings', () => { it("should handle text chunks with empty strings", () => {
const emptyTextChunk: ApiStreamChunk = { const emptyTextChunk: ApiStreamChunk = {
type: 'text', type: "text",
text: '' text: "",
}; }
expect(emptyTextChunk.type).toBe('text'); expect(emptyTextChunk.type).toBe("text")
expect(emptyTextChunk.text).toBe(''); expect(emptyTextChunk.text).toBe("")
}); })
it('should handle usage chunks with zero tokens', () => { it("should handle usage chunks with zero tokens", () => {
const zeroUsageChunk: ApiStreamChunk = { const zeroUsageChunk: ApiStreamChunk = {
type: 'usage', type: "usage",
inputTokens: 0, inputTokens: 0,
outputTokens: 0 outputTokens: 0,
}; }
expect(zeroUsageChunk.type).toBe('usage'); expect(zeroUsageChunk.type).toBe("usage")
expect(zeroUsageChunk.inputTokens).toBe(0); expect(zeroUsageChunk.inputTokens).toBe(0)
expect(zeroUsageChunk.outputTokens).toBe(0); expect(zeroUsageChunk.outputTokens).toBe(0)
}); })
it('should handle usage chunks with large token counts', () => { it("should handle usage chunks with large token counts", () => {
const largeUsageChunk: ApiStreamChunk = { const largeUsageChunk: ApiStreamChunk = {
type: 'usage', type: "usage",
inputTokens: 1000000, inputTokens: 1000000,
outputTokens: 500000, outputTokens: 500000,
cacheWriteTokens: 200000, cacheWriteTokens: 200000,
cacheReadTokens: 100000 cacheReadTokens: 100000,
}; }
expect(largeUsageChunk.type).toBe('usage'); expect(largeUsageChunk.type).toBe("usage")
expect(largeUsageChunk.inputTokens).toBe(1000000); expect(largeUsageChunk.inputTokens).toBe(1000000)
expect(largeUsageChunk.outputTokens).toBe(500000); expect(largeUsageChunk.outputTokens).toBe(500000)
expect(largeUsageChunk.cacheWriteTokens).toBe(200000); expect(largeUsageChunk.cacheWriteTokens).toBe(200000)
expect(largeUsageChunk.cacheReadTokens).toBe(100000); expect(largeUsageChunk.cacheReadTokens).toBe(100000)
}); })
it('should handle text chunks with special characters', () => { it("should handle text chunks with special characters", () => {
const specialCharsChunk: ApiStreamChunk = { const specialCharsChunk: ApiStreamChunk = {
type: 'text', type: "text",
text: '!@#$%^&*()_+-=[]{}|;:,.<>?`~' text: "!@#$%^&*()_+-=[]{}|;:,.<>?`~",
}; }
expect(specialCharsChunk.type).toBe('text'); expect(specialCharsChunk.type).toBe("text")
expect(specialCharsChunk.text).toBe('!@#$%^&*()_+-=[]{}|;:,.<>?`~'); expect(specialCharsChunk.text).toBe("!@#$%^&*()_+-=[]{}|;:,.<>?`~")
}); })
it('should handle text chunks with unicode characters', () => { it("should handle text chunks with unicode characters", () => {
const unicodeChunk: ApiStreamChunk = { const unicodeChunk: ApiStreamChunk = {
type: 'text', type: "text",
text: '你好世界👋🌍' text: "你好世界👋🌍",
}; }
expect(unicodeChunk.type).toBe('text'); expect(unicodeChunk.type).toBe("text")
expect(unicodeChunk.text).toBe('你好世界👋🌍'); expect(unicodeChunk.text).toBe("你好世界👋🌍")
}); })
it('should handle text chunks with multiline content', () => { it("should handle text chunks with multiline content", () => {
const multilineChunk: ApiStreamChunk = { const multilineChunk: ApiStreamChunk = {
type: 'text', type: "text",
text: 'Line 1\nLine 2\nLine 3' text: "Line 1\nLine 2\nLine 3",
}; }
expect(multilineChunk.type).toBe('text'); expect(multilineChunk.type).toBe("text")
expect(multilineChunk.text).toBe('Line 1\nLine 2\nLine 3'); expect(multilineChunk.text).toBe("Line 1\nLine 2\nLine 3")
expect(multilineChunk.text.split('\n')).toHaveLength(3); expect(multilineChunk.text.split("\n")).toHaveLength(3)
}); })
}); })
}); })

View File

@@ -1,66 +1,66 @@
import { Anthropic } from "@anthropic-ai/sdk"; import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from 'vscode'; import * as vscode from "vscode"
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from '../vscode-lm-format'; import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from "../vscode-lm-format"
// Mock crypto // Mock crypto
const mockCrypto = { const mockCrypto = {
randomUUID: () => 'test-uuid' randomUUID: () => "test-uuid",
}; }
global.crypto = mockCrypto as any; global.crypto = mockCrypto as any
// Define types for our mocked classes // Define types for our mocked classes
interface MockLanguageModelTextPart { interface MockLanguageModelTextPart {
type: 'text'; type: "text"
value: string; value: string
} }
interface MockLanguageModelToolCallPart { interface MockLanguageModelToolCallPart {
type: 'tool_call'; type: "tool_call"
callId: string; callId: string
name: string; name: string
input: any; input: any
} }
interface MockLanguageModelToolResultPart { interface MockLanguageModelToolResultPart {
type: 'tool_result'; type: "tool_result"
toolUseId: string; toolUseId: string
parts: MockLanguageModelTextPart[]; parts: MockLanguageModelTextPart[]
} }
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart; type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart
interface MockLanguageModelChatMessage { interface MockLanguageModelChatMessage {
role: string; role: string
name?: string; name?: string
content: MockMessageContent[]; content: MockMessageContent[]
} }
// Mock vscode namespace // Mock vscode namespace
jest.mock('vscode', () => { jest.mock("vscode", () => {
const LanguageModelChatMessageRole = { const LanguageModelChatMessageRole = {
Assistant: 'assistant', Assistant: "assistant",
User: 'user' User: "user",
}; }
class MockLanguageModelTextPart { class MockLanguageModelTextPart {
type = 'text'; type = "text"
constructor(public value: string) {} constructor(public value: string) {}
} }
class MockLanguageModelToolCallPart { class MockLanguageModelToolCallPart {
type = 'tool_call'; type = "tool_call"
constructor( constructor(
public callId: string, public callId: string,
public name: string, public name: string,
public input: any public input: any,
) {} ) {}
} }
class MockLanguageModelToolResultPart { class MockLanguageModelToolResultPart {
type = 'tool_result'; type = "tool_result"
constructor( constructor(
public toolUseId: string, public toolUseId: string,
public parts: MockLanguageModelTextPart[] public parts: MockLanguageModelTextPart[],
) {} ) {}
} }
@@ -68,179 +68,189 @@ jest.mock('vscode', () => {
LanguageModelChatMessage: { LanguageModelChatMessage: {
Assistant: jest.fn((content) => ({ Assistant: jest.fn((content) => ({
role: LanguageModelChatMessageRole.Assistant, role: LanguageModelChatMessageRole.Assistant,
name: 'assistant', name: "assistant",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})), })),
User: jest.fn((content) => ({ User: jest.fn((content) => ({
role: LanguageModelChatMessageRole.User, role: LanguageModelChatMessageRole.User,
name: 'user', name: "user",
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)] content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
})) })),
}, },
LanguageModelChatMessageRole, LanguageModelChatMessageRole,
LanguageModelTextPart: MockLanguageModelTextPart, LanguageModelTextPart: MockLanguageModelTextPart,
LanguageModelToolCallPart: MockLanguageModelToolCallPart, LanguageModelToolCallPart: MockLanguageModelToolCallPart,
LanguageModelToolResultPart: MockLanguageModelToolResultPart LanguageModelToolResultPart: MockLanguageModelToolResultPart,
}; }
}); })
describe('vscode-lm-format', () => { describe("vscode-lm-format", () => {
describe('convertToVsCodeLmMessages', () => { describe("convertToVsCodeLmMessages", () => {
it('should convert simple string messages', () => { it("should convert simple string messages", () => {
const messages: Anthropic.Messages.MessageParam[] = [ const messages: Anthropic.Messages.MessageParam[] = [
{ role: 'user', content: 'Hello' }, { role: "user", content: "Hello" },
{ role: 'assistant', content: 'Hi there' } { role: "assistant", content: "Hi there" },
];
const result = convertToVsCodeLmMessages(messages);
expect(result).toHaveLength(2);
expect(result[0].role).toBe('user');
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe('Hello');
expect(result[1].role).toBe('assistant');
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe('Hi there');
});
it('should handle complex user messages with tool results', () => {
const messages: Anthropic.Messages.MessageParam[] = [{
role: 'user',
content: [
{ type: 'text', text: 'Here is the result:' },
{
type: 'tool_result',
tool_use_id: 'tool-1',
content: 'Tool output'
}
] ]
}];
const result = convertToVsCodeLmMessages(messages); const result = convertToVsCodeLmMessages(messages)
expect(result).toHaveLength(1); expect(result).toHaveLength(2)
expect(result[0].role).toBe('user'); expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(2); expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe("Hello")
const [toolResult, textContent] = result[0].content as [MockLanguageModelToolResultPart, MockLanguageModelTextPart]; expect(result[1].role).toBe("assistant")
expect(toolResult.type).toBe('tool_result'); expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe("Hi there")
expect(textContent.type).toBe('text'); })
});
it('should handle complex assistant messages with tool calls', () => { it("should handle complex user messages with tool results", () => {
const messages: Anthropic.Messages.MessageParam[] = [{ const messages: Anthropic.Messages.MessageParam[] = [
role: 'assistant',
content: [
{ type: 'text', text: 'Let me help you with that.' },
{ {
type: 'tool_use', role: "user",
id: 'tool-1', content: [
name: 'calculator', { type: "text", text: "Here is the result:" },
input: { operation: 'add', numbers: [2, 2] } {
} type: "tool_result",
tool_use_id: "tool-1",
content: "Tool output",
},
],
},
] ]
}];
const result = convertToVsCodeLmMessages(messages); const result = convertToVsCodeLmMessages(messages)
expect(result).toHaveLength(1); expect(result).toHaveLength(1)
expect(result[0].role).toBe('assistant'); expect(result[0].role).toBe("user")
expect(result[0].content).toHaveLength(2); expect(result[0].content).toHaveLength(2)
const [toolCall, textContent] = result[0].content as [MockLanguageModelToolCallPart, MockLanguageModelTextPart]; const [toolResult, textContent] = result[0].content as [
expect(toolCall.type).toBe('tool_call'); MockLanguageModelToolResultPart,
expect(textContent.type).toBe('text'); MockLanguageModelTextPart,
}); ]
expect(toolResult.type).toBe("tool_result")
expect(textContent.type).toBe("text")
})
it('should handle image blocks with appropriate placeholders', () => { it("should handle complex assistant messages with tool calls", () => {
const messages: Anthropic.Messages.MessageParam[] = [{ const messages: Anthropic.Messages.MessageParam[] = [
role: 'user',
content: [
{ type: 'text', text: 'Look at this:' },
{ {
type: 'image', role: "assistant",
content: [
{ type: "text", text: "Let me help you with that." },
{
type: "tool_use",
id: "tool-1",
name: "calculator",
input: { operation: "add", numbers: [2, 2] },
},
],
},
]
const result = convertToVsCodeLmMessages(messages)
expect(result).toHaveLength(1)
expect(result[0].role).toBe("assistant")
expect(result[0].content).toHaveLength(2)
const [toolCall, textContent] = result[0].content as [
MockLanguageModelToolCallPart,
MockLanguageModelTextPart,
]
expect(toolCall.type).toBe("tool_call")
expect(textContent.type).toBe("text")
})
it("should handle image blocks with appropriate placeholders", () => {
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{ type: "text", text: "Look at this:" },
{
type: "image",
source: { source: {
type: 'base64', type: "base64",
media_type: 'image/png', media_type: "image/png",
data: 'base64data' data: "base64data",
} },
} },
],
},
] ]
}];
const result = convertToVsCodeLmMessages(messages); const result = convertToVsCodeLmMessages(messages)
expect(result).toHaveLength(1); expect(result).toHaveLength(1)
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart; const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart
expect(imagePlaceholder.value).toContain('[Image (base64): image/png not supported by VSCode LM API]'); expect(imagePlaceholder.value).toContain("[Image (base64): image/png not supported by VSCode LM API]")
}); })
}); })
describe('convertToAnthropicRole', () => { describe("convertToAnthropicRole", () => {
it('should convert assistant role correctly', () => { it("should convert assistant role correctly", () => {
const result = convertToAnthropicRole('assistant' as any); const result = convertToAnthropicRole("assistant" as any)
expect(result).toBe('assistant'); expect(result).toBe("assistant")
}); })
it('should convert user role correctly', () => { it("should convert user role correctly", () => {
const result = convertToAnthropicRole('user' as any); const result = convertToAnthropicRole("user" as any)
expect(result).toBe('user'); expect(result).toBe("user")
}); })
it('should return null for unknown roles', () => { it("should return null for unknown roles", () => {
const result = convertToAnthropicRole('unknown' as any); const result = convertToAnthropicRole("unknown" as any)
expect(result).toBeNull(); expect(result).toBeNull()
}); })
}); })
describe('convertToAnthropicMessage', () => { describe("convertToAnthropicMessage", () => {
it('should convert assistant message with text content', async () => { it("should convert assistant message with text content", async () => {
const vsCodeMessage = { const vsCodeMessage = {
role: 'assistant', role: "assistant",
name: 'assistant', name: "assistant",
content: [new vscode.LanguageModelTextPart('Hello')] content: [new vscode.LanguageModelTextPart("Hello")],
}; }
const result = await convertToAnthropicMessage(vsCodeMessage as any); const result = await convertToAnthropicMessage(vsCodeMessage as any)
expect(result.role).toBe('assistant'); expect(result.role).toBe("assistant")
expect(result.content).toHaveLength(1); expect(result.content).toHaveLength(1)
expect(result.content[0]).toEqual({ expect(result.content[0]).toEqual({
type: 'text', type: "text",
text: 'Hello' text: "Hello",
}); })
expect(result.id).toBe('test-uuid'); expect(result.id).toBe("test-uuid")
}); })
it('should convert assistant message with tool calls', async () => { it("should convert assistant message with tool calls", async () => {
const vsCodeMessage = { const vsCodeMessage = {
role: 'assistant', role: "assistant",
name: 'assistant', name: "assistant",
content: [new vscode.LanguageModelToolCallPart( content: [
'call-1', new vscode.LanguageModelToolCallPart("call-1", "calculator", { operation: "add", numbers: [2, 2] }),
'calculator', ],
{ operation: 'add', numbers: [2, 2] } }
)]
};
const result = await convertToAnthropicMessage(vsCodeMessage as any); const result = await convertToAnthropicMessage(vsCodeMessage as any)
expect(result.content).toHaveLength(1); expect(result.content).toHaveLength(1)
expect(result.content[0]).toEqual({ expect(result.content[0]).toEqual({
type: 'tool_use', type: "tool_use",
id: 'call-1', id: "call-1",
name: 'calculator', name: "calculator",
input: { operation: 'add', numbers: [2, 2] } input: { operation: "add", numbers: [2, 2] },
}); })
expect(result.id).toBe('test-uuid'); expect(result.id).toBe("test-uuid")
}); })
it('should throw error for non-assistant messages', async () => { it("should throw error for non-assistant messages", async () => {
const vsCodeMessage = { const vsCodeMessage = {
role: 'user', role: "user",
name: 'user', name: "user",
content: [new vscode.LanguageModelTextPart('Hello')] content: [new vscode.LanguageModelTextPart("Hello")],
}; }
await expect(convertToAnthropicMessage(vsCodeMessage as any)) await expect(convertToAnthropicMessage(vsCodeMessage as any)).rejects.toThrow(
.rejects "Cline <Language Model API>: Only assistant messages are supported.",
.toThrow('Cline <Language Model API>: Only assistant messages are supported.'); )
}); })
}); })
}); })

View File

@@ -8,41 +8,41 @@ import { StreamEvent } from "../providers/bedrock"
/** /**
* Convert Anthropic messages to Bedrock Converse format * Convert Anthropic messages to Bedrock Converse format
*/ */
export function convertToBedrockConverseMessages( export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
anthropicMessages: Anthropic.Messages.MessageParam[] return anthropicMessages.map((anthropicMessage) => {
): Message[] {
return anthropicMessages.map(anthropicMessage => {
// Map Anthropic roles to Bedrock roles // Map Anthropic roles to Bedrock roles
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user" const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
if (typeof anthropicMessage.content === "string") { if (typeof anthropicMessage.content === "string") {
return { return {
role, role,
content: [{ content: [
text: anthropicMessage.content {
}] as ContentBlock[] text: anthropicMessage.content,
},
] as ContentBlock[],
} }
} }
// Process complex content types // Process complex content types
const content = anthropicMessage.content.map(block => { const content = anthropicMessage.content.map((block) => {
const messageBlock = block as MessageContent & { const messageBlock = block as MessageContent & {
id?: string, id?: string
tool_use_id?: string, tool_use_id?: string
content?: Array<{ type: string, text: string }>, content?: Array<{ type: string; text: string }>
output?: string | Array<{ type: string, text: string }> output?: string | Array<{ type: string; text: string }>
} }
if (messageBlock.type === "text") { if (messageBlock.type === "text") {
return { return {
text: messageBlock.text || '' text: messageBlock.text || "",
} as ContentBlock } as ContentBlock
} }
if (messageBlock.type === "image" && messageBlock.source) { if (messageBlock.type === "image" && messageBlock.source) {
// Convert base64 string to byte array if needed // Convert base64 string to byte array if needed
let byteArray: Uint8Array let byteArray: Uint8Array
if (typeof messageBlock.source.data === 'string') { if (typeof messageBlock.source.data === "string") {
const binaryString = atob(messageBlock.source.data) const binaryString = atob(messageBlock.source.data)
byteArray = new Uint8Array(binaryString.length) byteArray = new Uint8Array(binaryString.length)
for (let i = 0; i < binaryString.length; i++) { for (let i = 0; i < binaryString.length; i++) {
@@ -53,8 +53,8 @@ export function convertToBedrockConverseMessages(
} }
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg") // Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
const format = messageBlock.source.media_type.split('/')[1] const format = messageBlock.source.media_type.split("/")[1]
if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) { if (!["png", "jpeg", "gif", "webp"].includes(format)) {
throw new Error(`Unsupported image format: ${format}`) throw new Error(`Unsupported image format: ${format}`)
} }
@@ -62,9 +62,9 @@ export function convertToBedrockConverseMessages(
image: { image: {
format: format as "png" | "jpeg" | "gif" | "webp", format: format as "png" | "jpeg" | "gif" | "webp",
source: { source: {
bytes: byteArray bytes: byteArray,
} },
} },
} as ContentBlock } as ContentBlock
} }
@@ -72,14 +72,14 @@ export function convertToBedrockConverseMessages(
// Convert tool use to XML format // Convert tool use to XML format
const toolParams = Object.entries(messageBlock.input || {}) const toolParams = Object.entries(messageBlock.input || {})
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`) .map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
.join('\n') .join("\n")
return { return {
toolUse: { toolUse: {
toolUseId: messageBlock.id || '', toolUseId: messageBlock.id || "",
name: messageBlock.name || '', name: messageBlock.name || "",
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>` input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`,
} },
} as ContentBlock } as ContentBlock
} }
@@ -88,12 +88,12 @@ export function convertToBedrockConverseMessages(
if (messageBlock.content && Array.isArray(messageBlock.content)) { if (messageBlock.content && Array.isArray(messageBlock.content)) {
return { return {
toolResult: { toolResult: {
toolUseId: messageBlock.tool_use_id || '', toolUseId: messageBlock.tool_use_id || "",
content: messageBlock.content.map(item => ({ content: messageBlock.content.map((item) => ({
text: item.text text: item.text,
})), })),
status: "success" status: "success",
} },
} as ContentBlock } as ContentBlock
} }
@@ -101,20 +101,22 @@ export function convertToBedrockConverseMessages(
if (messageBlock.output && typeof messageBlock.output === "string") { if (messageBlock.output && typeof messageBlock.output === "string") {
return { return {
toolResult: { toolResult: {
toolUseId: messageBlock.tool_use_id || '', toolUseId: messageBlock.tool_use_id || "",
content: [{ content: [
text: messageBlock.output {
}], text: messageBlock.output,
status: "success" },
} ],
status: "success",
},
} as ContentBlock } as ContentBlock
} }
// Handle array of content blocks if output is an array // Handle array of content blocks if output is an array
if (Array.isArray(messageBlock.output)) { if (Array.isArray(messageBlock.output)) {
return { return {
toolResult: { toolResult: {
toolUseId: messageBlock.tool_use_id || '', toolUseId: messageBlock.tool_use_id || "",
content: messageBlock.output.map(part => { content: messageBlock.output.map((part) => {
if (typeof part === "object" && "text" in part) { if (typeof part === "object" && "text" in part) {
return { text: part.text } return { text: part.text }
} }
@@ -124,48 +126,52 @@ export function convertToBedrockConverseMessages(
} }
return { text: String(part) } return { text: String(part) }
}), }),
status: "success" status: "success",
} },
} as ContentBlock } as ContentBlock
} }
// Default case // Default case
return { return {
toolResult: { toolResult: {
toolUseId: messageBlock.tool_use_id || '', toolUseId: messageBlock.tool_use_id || "",
content: [{ content: [
text: String(messageBlock.output || '') {
}], text: String(messageBlock.output || ""),
status: "success" },
} ],
status: "success",
},
} as ContentBlock } as ContentBlock
} }
if (messageBlock.type === "video") { if (messageBlock.type === "video") {
const videoContent = messageBlock.s3Location ? { const videoContent = messageBlock.s3Location
? {
s3Location: { s3Location: {
uri: messageBlock.s3Location.uri, uri: messageBlock.s3Location.uri,
bucketOwner: messageBlock.s3Location.bucketOwner bucketOwner: messageBlock.s3Location.bucketOwner,
},
} }
} : messageBlock.source : messageBlock.source
return { return {
video: { video: {
format: "mp4", // Default to mp4, adjust based on actual format if needed format: "mp4", // Default to mp4, adjust based on actual format if needed
source: videoContent source: videoContent,
} },
} as ContentBlock } as ContentBlock
} }
// Default case for unknown block types // Default case for unknown block types
return { return {
text: '[Unknown Block Type]' text: "[Unknown Block Type]",
} as ContentBlock } as ContentBlock
}) })
return { return {
role, role,
content content,
} }
}) })
} }
@@ -175,19 +181,19 @@ export function convertToBedrockConverseMessages(
*/ */
export function convertToAnthropicMessage( export function convertToAnthropicMessage(
streamEvent: StreamEvent, streamEvent: StreamEvent,
modelId: string modelId: string,
): Partial<Anthropic.Messages.Message> { ): Partial<Anthropic.Messages.Message> {
// Handle metadata events // Handle metadata events
if (streamEvent.metadata?.usage) { if (streamEvent.metadata?.usage) {
return { return {
id: '', // Bedrock doesn't provide message IDs id: "", // Bedrock doesn't provide message IDs
type: "message", type: "message",
role: "assistant", role: "assistant",
model: modelId, model: modelId,
usage: { usage: {
input_tokens: streamEvent.metadata.usage.inputTokens || 0, input_tokens: streamEvent.metadata.usage.inputTokens || 0,
output_tokens: streamEvent.metadata.usage.outputTokens || 0 output_tokens: streamEvent.metadata.usage.outputTokens || 0,
} },
} }
} }
@@ -198,7 +204,7 @@ export function convertToAnthropicMessage(
type: "message", type: "message",
role: "assistant", role: "assistant",
content: [{ type: "text", text: text }], content: [{ type: "text", text: text }],
model: modelId model: modelId,
} }
} }
@@ -209,7 +215,7 @@ export function convertToAnthropicMessage(
role: "assistant", role: "assistant",
stop_reason: streamEvent.messageStop.stopReason || null, stop_reason: streamEvent.messageStop.stopReason || null,
stop_sequence: null, stop_sequence: null,
model: modelId model: modelId,
} }
} }

View File

@@ -1,5 +1,5 @@
import { Anthropic } from "@anthropic-ai/sdk"; import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from 'vscode'; import * as vscode from "vscode"
/** /**
* Safely converts a value into a plain object. * Safely converts a value into a plain object.
@@ -7,30 +7,31 @@ import * as vscode from 'vscode';
function asObjectSafe(value: any): object { function asObjectSafe(value: any): object {
// Handle null/undefined // Handle null/undefined
if (!value) { if (!value) {
return {}; return {}
} }
try { try {
// Handle strings that might be JSON // Handle strings that might be JSON
if (typeof value === 'string') { if (typeof value === "string") {
return JSON.parse(value); return JSON.parse(value)
} }
// Handle pre-existing objects // Handle pre-existing objects
if (typeof value === 'object') { if (typeof value === "object") {
return Object.assign({}, value); return Object.assign({}, value)
} }
return {}; return {}
} } catch (error) {
catch (error) { console.warn("Cline <Language Model API>: Failed to parse object:", error)
console.warn('Cline <Language Model API>: Failed to parse object:', error); return {}
return {};
} }
} }
export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): vscode.LanguageModelChatMessage[] { export function convertToVsCodeLmMessages(
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []; anthropicMessages: Anthropic.Messages.MessageParam[],
): vscode.LanguageModelChatMessage[] {
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []
for (const anthropicMessage of anthropicMessages) { for (const anthropicMessage of anthropicMessages) {
// Handle simple string messages // Handle simple string messages
@@ -38,135 +39,129 @@ export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.
vsCodeLmMessages.push( vsCodeLmMessages.push(
anthropicMessage.role === "assistant" anthropicMessage.role === "assistant"
? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content) ? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content)
: vscode.LanguageModelChatMessage.User(anthropicMessage.content) : vscode.LanguageModelChatMessage.User(anthropicMessage.content),
); )
continue; continue
} }
// Handle complex message structures // Handle complex message structures
switch (anthropicMessage.role) { switch (anthropicMessage.role) {
case "user": { case "user": {
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]; nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
toolMessages: Anthropic.ToolResultBlockParam[]; toolMessages: Anthropic.ToolResultBlockParam[]
}>( }>(
(acc, part) => { (acc, part) => {
if (part.type === "tool_result") { if (part.type === "tool_result") {
acc.toolMessages.push(part); acc.toolMessages.push(part)
} else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part)
} }
else if (part.type === "text" || part.type === "image") { return acc
acc.nonToolMessages.push(part);
}
return acc;
}, },
{ nonToolMessages: [], toolMessages: [] }, { nonToolMessages: [], toolMessages: [] },
); )
// Process tool messages first then non-tool messages // Process tool messages first then non-tool messages
const contentParts = [ const contentParts = [
// Convert tool messages to ToolResultParts // Convert tool messages to ToolResultParts
...toolMessages.map((toolMessage) => { ...toolMessages.map((toolMessage) => {
// Process tool result content into TextParts // Process tool result content into TextParts
const toolContentParts: vscode.LanguageModelTextPart[] = ( const toolContentParts: vscode.LanguageModelTextPart[] =
typeof toolMessage.content === "string" typeof toolMessage.content === "string"
? [new vscode.LanguageModelTextPart(toolMessage.content)] ? [new vscode.LanguageModelTextPart(toolMessage.content)]
: ( : (toolMessage.content?.map((part) => {
toolMessage.content?.map((part) => {
if (part.type === "image") { if (part.type === "image") {
return new vscode.LanguageModelTextPart( return new vscode.LanguageModelTextPart(
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]` `[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
);
}
return new vscode.LanguageModelTextPart(part.text);
})
?? [new vscode.LanguageModelTextPart("")]
) )
); }
return new vscode.LanguageModelTextPart(part.text)
}) ?? [new vscode.LanguageModelTextPart("")])
return new vscode.LanguageModelToolResultPart( return new vscode.LanguageModelToolResultPart(toolMessage.tool_use_id, toolContentParts)
toolMessage.tool_use_id,
toolContentParts
);
}), }),
// Convert non-tool messages to TextParts after tool messages // Convert non-tool messages to TextParts after tool messages
...nonToolMessages.map((part) => { ...nonToolMessages.map((part) => {
if (part.type === "image") { if (part.type === "image") {
return new vscode.LanguageModelTextPart( return new vscode.LanguageModelTextPart(
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]` `[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
); )
} }
return new vscode.LanguageModelTextPart(part.text); return new vscode.LanguageModelTextPart(part.text)
}) }),
]; ]
// Add single user message with all content parts // Add single user message with all content parts
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts)); vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts))
break; break
} }
case "assistant": { case "assistant": {
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]; nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
toolMessages: Anthropic.ToolUseBlockParam[]; toolMessages: Anthropic.ToolUseBlockParam[]
}>( }>(
(acc, part) => { (acc, part) => {
if (part.type === "tool_use") { if (part.type === "tool_use") {
acc.toolMessages.push(part); acc.toolMessages.push(part)
} else if (part.type === "text" || part.type === "image") {
acc.nonToolMessages.push(part)
} }
else if (part.type === "text" || part.type === "image") { return acc
acc.nonToolMessages.push(part);
}
return acc;
}, },
{ nonToolMessages: [], toolMessages: [] }, { nonToolMessages: [], toolMessages: [] },
); )
// Process tool messages first then non-tool messages // Process tool messages first then non-tool messages
const contentParts = [ const contentParts = [
// Convert tool messages to ToolCallParts first // Convert tool messages to ToolCallParts first
...toolMessages.map((toolMessage) => ...toolMessages.map(
(toolMessage) =>
new vscode.LanguageModelToolCallPart( new vscode.LanguageModelToolCallPart(
toolMessage.id, toolMessage.id,
toolMessage.name, toolMessage.name,
asObjectSafe(toolMessage.input) asObjectSafe(toolMessage.input),
) ),
), ),
// Convert non-tool messages to TextParts after tool messages // Convert non-tool messages to TextParts after tool messages
...nonToolMessages.map((part) => { ...nonToolMessages.map((part) => {
if (part.type === "image") { if (part.type === "image") {
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]"); return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]")
} }
return new vscode.LanguageModelTextPart(part.text); return new vscode.LanguageModelTextPart(part.text)
}) }),
]; ]
// Add the assistant message to the list of messages // Add the assistant message to the list of messages
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts)); vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts))
break; break
} }
} }
} }
return vsCodeLmMessages; return vsCodeLmMessages
} }
export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null { export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null {
switch (vsCodeLmMessageRole) { switch (vsCodeLmMessageRole) {
case vscode.LanguageModelChatMessageRole.Assistant: case vscode.LanguageModelChatMessageRole.Assistant:
return "assistant"; return "assistant"
case vscode.LanguageModelChatMessageRole.User: case vscode.LanguageModelChatMessageRole.User:
return "user"; return "user"
default: default:
return null; return null
} }
} }
export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.LanguageModelChatMessage): Promise<Anthropic.Messages.Message> { export async function convertToAnthropicMessage(
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role); vsCodeLmMessage: vscode.LanguageModelChatMessage,
): Promise<Anthropic.Messages.Message> {
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role)
if (anthropicRole !== "assistant") { if (anthropicRole !== "assistant") {
throw new Error("Cline <Language Model API>: Only assistant messages are supported."); throw new Error("Cline <Language Model API>: Only assistant messages are supported.")
} }
return { return {
@@ -174,14 +169,13 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
type: "message", type: "message",
model: "vscode-lm", model: "vscode-lm",
role: anthropicRole, role: anthropicRole,
content: ( content: vsCodeLmMessage.content
vsCodeLmMessage.content
.map((part): Anthropic.ContentBlock | null => { .map((part): Anthropic.ContentBlock | null => {
if (part instanceof vscode.LanguageModelTextPart) { if (part instanceof vscode.LanguageModelTextPart) {
return { return {
type: "text", type: "text",
text: part.value text: part.value,
}; }
} }
if (part instanceof vscode.LanguageModelToolCallPart) { if (part instanceof vscode.LanguageModelToolCallPart) {
@@ -189,21 +183,18 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
type: "tool_use", type: "tool_use",
id: part.callId || crypto.randomUUID(), id: part.callId || crypto.randomUUID(),
name: part.name, name: part.name,
input: asObjectSafe(part.input) input: asObjectSafe(part.input),
}; }
} }
return null; return null
}) })
.filter( .filter((part): part is Anthropic.ContentBlock => part !== null),
(part): part is Anthropic.ContentBlock => part !== null
)
),
stop_reason: null, stop_reason: null,
stop_sequence: null, stop_sequence: null,
usage: { usage: {
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
},
} }
};
} }

View File

@@ -13,7 +13,13 @@ import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
import { ApiStream } from "../api/transform/stream" import { ApiStream } from "../api/transform/stream"
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider" import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown" import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
import { extractTextFromFile, addLineNumbers, stripLineNumbers, everyLineHasLineNumbers, truncateOutput } from "../integrations/misc/extract-text" import {
extractTextFromFile,
addLineNumbers,
stripLineNumbers,
everyLineHasLineNumbers,
truncateOutput,
} from "../integrations/misc/extract-text"
import { TerminalManager } from "../integrations/terminal/TerminalManager" import { TerminalManager } from "../integrations/terminal/TerminalManager"
import { UrlContentFetcher } from "../services/browser/UrlContentFetcher" import { UrlContentFetcher } from "../services/browser/UrlContentFetcher"
import { listFiles } from "../services/glob/list-files" import { listFiles } from "../services/glob/list-files"
@@ -45,7 +51,8 @@ import { arePathsEqual, getReadablePath } from "../utils/path"
import { parseMentions } from "./mentions" import { parseMentions } from "./mentions"
import { AssistantMessageContent, parseAssistantMessage, ToolParamName, ToolUseName } from "./assistant-message" import { AssistantMessageContent, parseAssistantMessage, ToolParamName, ToolUseName } from "./assistant-message"
import { formatResponse } from "./prompts/responses" import { formatResponse } from "./prompts/responses"
import { addCustomInstructions, codeMode, SYSTEM_PROMPT } from "./prompts/system" import { addCustomInstructions, SYSTEM_PROMPT } from "./prompts/system"
import { modes, defaultModeSlug } from "../shared/modes"
import { truncateHalfConversation } from "./sliding-window" import { truncateHalfConversation } from "./sliding-window"
import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider" import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider"
import { detectCodeOmission } from "../integrations/editor/detect-omission" import { detectCodeOmission } from "../integrations/editor/detect-omission"
@@ -111,7 +118,7 @@ export class Cline {
experimentalDiffStrategy: boolean = false, experimentalDiffStrategy: boolean = false,
) { ) {
if (!task && !images && !historyItem) { if (!task && !images && !historyItem) {
throw new Error('Either historyItem or task/images must be provided'); throw new Error("Either historyItem or task/images must be provided")
} }
this.taskId = crypto.randomUUID() this.taskId = crypto.randomUUID()
@@ -143,7 +150,8 @@ export class Cline {
async updateDiffStrategy(experimentalDiffStrategy?: boolean) { async updateDiffStrategy(experimentalDiffStrategy?: boolean) {
// If not provided, get from current state // If not provided, get from current state
if (experimentalDiffStrategy === undefined) { if (experimentalDiffStrategy === undefined) {
const { experimentalDiffStrategy: stateExperimentalDiffStrategy } = await this.providerRef.deref()?.getState() ?? {} const { experimentalDiffStrategy: stateExperimentalDiffStrategy } =
(await this.providerRef.deref()?.getState()) ?? {}
experimentalDiffStrategy = stateExperimentalDiffStrategy ?? false experimentalDiffStrategy = stateExperimentalDiffStrategy ?? false
} }
this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy) this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy)
@@ -755,8 +763,8 @@ export class Cline {
// grouping command_output messages despite any gaps anyways) // grouping command_output messages despite any gaps anyways)
await delay(50) await delay(50)
const { terminalOutputLineLimit } = await this.providerRef.deref()?.getState() ?? {} const { terminalOutputLineLimit } = (await this.providerRef.deref()?.getState()) ?? {}
const output = truncateOutput(lines.join('\n'), terminalOutputLineLimit) const output = truncateOutput(lines.join("\n"), terminalOutputLineLimit)
const result = output.trim() const result = output.trim()
if (userFeedback) { if (userFeedback) {
@@ -787,7 +795,8 @@ export class Cline {
async *attemptApiRequest(previousApiReqIndex: number): ApiStream { async *attemptApiRequest(previousApiReqIndex: number): ApiStream {
let mcpHub: McpHub | undefined let mcpHub: McpHub | undefined
const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } = await this.providerRef.deref()?.getState() ?? {} const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } =
(await this.providerRef.deref()?.getState()) ?? {}
if (mcpEnabled ?? true) { if (mcpEnabled ?? true) {
mcpHub = this.providerRef.deref()?.mcpHub mcpHub = this.providerRef.deref()?.mcpHub
@@ -800,24 +809,27 @@ export class Cline {
}) })
} }
const { browserViewportSize, preferredLanguage, mode, customPrompts } = await this.providerRef.deref()?.getState() ?? {} const { browserViewportSize, preferredLanguage, mode, customPrompts } =
const systemPrompt = await SYSTEM_PROMPT( (await this.providerRef.deref()?.getState()) ?? {}
const systemPrompt =
(await SYSTEM_PROMPT(
cwd, cwd,
this.api.getModel().info.supportsComputerUse ?? false, this.api.getModel().info.supportsComputerUse ?? false,
mcpHub, mcpHub,
this.diffStrategy, this.diffStrategy,
browserViewportSize, browserViewportSize,
mode, mode,
customPrompts customPrompts,
) + await addCustomInstructions( )) +
(await addCustomInstructions(
{ {
customInstructions: this.customInstructions, customInstructions: this.customInstructions,
customPrompts, customPrompts,
preferredLanguage preferredLanguage,
}, },
cwd, cwd,
mode mode,
) ))
// If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request // If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request
if (previousApiReqIndex >= 0) { if (previousApiReqIndex >= 0) {
@@ -844,18 +856,18 @@ export class Cline {
if (Array.isArray(content)) { if (Array.isArray(content)) {
if (!this.api.getModel().info.supportsImages) { if (!this.api.getModel().info.supportsImages) {
// Convert image blocks to text descriptions // Convert image blocks to text descriptions
content = content.map(block => { content = content.map((block) => {
if (block.type === 'image') { if (block.type === "image") {
// Convert image blocks to text descriptions // Convert image blocks to text descriptions
// Note: We can't access the actual image content/url due to API limitations, // Note: We can't access the actual image content/url due to API limitations,
// but we can indicate that an image was present in the conversation // but we can indicate that an image was present in the conversation
return { return {
type: 'text', type: "text",
text: '[Referenced image in conversation]' text: "[Referenced image in conversation]",
};
} }
return block; }
}); return block
})
} }
} }
return { role, content } return { role, content }
@@ -875,7 +887,12 @@ export class Cline {
// Automatically retry with delay // Automatically retry with delay
// Show countdown timer in error color // Show countdown timer in error color
for (let i = requestDelay; i > 0; i--) { for (let i = requestDelay; i > 0; i--) {
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying in ${i} seconds...`, undefined, true) await this.say(
"api_req_retry_delayed",
`${errorMsg}\n\nRetrying in ${i} seconds...`,
undefined,
true,
)
await delay(1000) await delay(1000)
} }
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying now...`, undefined, false) await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying now...`, undefined, false)
@@ -1124,9 +1141,9 @@ export class Cline {
} }
// Validate tool use based on current mode // Validate tool use based on current mode
const { mode } = await this.providerRef.deref()?.getState() ?? {} const { mode } = (await this.providerRef.deref()?.getState()) ?? {}
try { try {
validateToolUse(block.name, mode ?? codeMode) validateToolUse(block.name, mode ?? defaultModeSlug)
} catch (error) { } catch (error) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
pushToolResult(formatResponse.toolError(error.message)) pushToolResult(formatResponse.toolError(error.message))
@@ -1191,7 +1208,10 @@ export class Cline {
await this.diffViewProvider.open(relPath) await this.diffViewProvider.open(relPath)
} }
// editor is open, stream content in // editor is open, stream content in
await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, false) await this.diffViewProvider.update(
everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent,
false,
)
break break
} else { } else {
if (!relPath) { if (!relPath) {
@@ -1208,7 +1228,9 @@ export class Cline {
} }
if (!predictedLineCount) { if (!predictedLineCount) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
pushToolResult(await this.sayAndCreateMissingParamError("write_to_file", "line_count")) pushToolResult(
await this.sayAndCreateMissingParamError("write_to_file", "line_count"),
)
await this.diffViewProvider.reset() await this.diffViewProvider.reset()
break break
} }
@@ -1223,17 +1245,28 @@ export class Cline {
await this.ask("tool", partialMessage, true).catch(() => {}) // sending true for partial even though it's not a partial, this shows the edit row before the content is streamed into the editor await this.ask("tool", partialMessage, true).catch(() => {}) // sending true for partial even though it's not a partial, this shows the edit row before the content is streamed into the editor
await this.diffViewProvider.open(relPath) await this.diffViewProvider.open(relPath)
} }
await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, true) await this.diffViewProvider.update(
everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent,
true,
)
await delay(300) // wait for diff view to update await delay(300) // wait for diff view to update
this.diffViewProvider.scrollToFirstDiff() this.diffViewProvider.scrollToFirstDiff()
// Check for code omissions before proceeding // Check for code omissions before proceeding
if (detectCodeOmission(this.diffViewProvider.originalContent || "", newContent, predictedLineCount)) { if (
detectCodeOmission(
this.diffViewProvider.originalContent || "",
newContent,
predictedLineCount,
)
) {
if (this.diffStrategy) { if (this.diffStrategy) {
await this.diffViewProvider.revertChanges() await this.diffViewProvider.revertChanges()
pushToolResult(formatResponse.toolError( pushToolResult(
`Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.` formatResponse.toolError(
)) `Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.`,
),
)
break break
} else { } else {
vscode.window vscode.window
@@ -1284,7 +1317,7 @@ export class Cline {
pushToolResult( pushToolResult(
`The user made the following updates to your content:\n\n${userEdits}\n\n` + `The user made the following updates to your content:\n\n${userEdits}\n\n` +
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` + `The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || '')}\n</final_file_content>\n\n` + `<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || "")}\n</final_file_content>\n\n` +
`Please note:\n` + `Please note:\n` +
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` + `1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
`2. Proceed with the task using this updated file content as the new baseline.\n` + `2. Proceed with the task using this updated file content as the new baseline.\n` +
@@ -1346,21 +1379,24 @@ export class Cline {
const originalContent = await fs.readFile(absolutePath, "utf-8") const originalContent = await fs.readFile(absolutePath, "utf-8")
// Apply the diff to the original content // Apply the diff to the original content
const diffResult = await this.diffStrategy?.applyDiff( const diffResult = (await this.diffStrategy?.applyDiff(
originalContent, originalContent,
diffContent, diffContent,
parseInt(block.params.start_line ?? ''), parseInt(block.params.start_line ?? ""),
parseInt(block.params.end_line ?? '') parseInt(block.params.end_line ?? ""),
) ?? { )) ?? {
success: false, success: false,
error: "No diff strategy available" error: "No diff strategy available",
} }
if (!diffResult.success) { if (!diffResult.success) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
const currentCount = (this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1 const currentCount =
(this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1
this.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount) this.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount)
const errorDetails = diffResult.details ? JSON.stringify(diffResult.details, null, 2) : '' const errorDetails = diffResult.details
const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n<error_details>\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ''}\n</error_details>` ? JSON.stringify(diffResult.details, null, 2)
: ""
const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n<error_details>\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ""}\n</error_details>`
if (currentCount >= 2) { if (currentCount >= 2) {
await this.say("error", formattedError) await this.say("error", formattedError)
} }
@@ -1372,9 +1408,9 @@ export class Cline {
this.consecutiveMistakeCountForApplyDiff.delete(relPath) this.consecutiveMistakeCountForApplyDiff.delete(relPath)
// Show diff view before asking for approval // Show diff view before asking for approval
this.diffViewProvider.editType = "modify" this.diffViewProvider.editType = "modify"
await this.diffViewProvider.open(relPath); await this.diffViewProvider.open(relPath)
await this.diffViewProvider.update(diffResult.content, true); await this.diffViewProvider.update(diffResult.content, true)
await this.diffViewProvider.scrollToFirstDiff(); await this.diffViewProvider.scrollToFirstDiff()
const completeMessage = JSON.stringify({ const completeMessage = JSON.stringify({
...sharedMessageProps, ...sharedMessageProps,
@@ -1402,7 +1438,7 @@ export class Cline {
pushToolResult( pushToolResult(
`The user made the following updates to your content:\n\n${userEdits}\n\n` + `The user made the following updates to your content:\n\n${userEdits}\n\n` +
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` + `The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || '')}\n</final_file_content>\n\n` + `<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || "")}\n</final_file_content>\n\n` +
`Please note:\n` + `Please note:\n` +
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` + `1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
`2. Proceed with the task using this updated file content as the new baseline.\n` + `2. Proceed with the task using this updated file content as the new baseline.\n` +
@@ -1410,7 +1446,9 @@ export class Cline {
`${newProblemsMessage}`, `${newProblemsMessage}`,
) )
} else { } else {
pushToolResult(`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`) pushToolResult(
`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`,
)
} }
await this.diffViewProvider.reset() await this.diffViewProvider.reset()
break break
@@ -1614,7 +1652,7 @@ export class Cline {
await this.ask( await this.ask(
"browser_action_launch", "browser_action_launch",
removeClosingTag("url", url), removeClosingTag("url", url),
block.partial block.partial,
).catch(() => {}) ).catch(() => {})
} else { } else {
await this.say( await this.say(
@@ -1743,7 +1781,7 @@ export class Cline {
try { try {
if (block.partial) { if (block.partial) {
await this.ask("command", removeClosingTag("command", command), block.partial).catch( await this.ask("command", removeClosingTag("command", command), block.partial).catch(
() => {} () => {},
) )
break break
} else { } else {
@@ -2408,7 +2446,7 @@ export class Cline {
Promise.all( Promise.all(
userContent.map(async (block) => { userContent.map(async (block) => {
const shouldProcessMentions = (text: string) => const shouldProcessMentions = (text: string) =>
text.includes("<task>") || text.includes("<feedback>"); text.includes("<task>") || text.includes("<feedback>")
if (block.type === "text") { if (block.type === "text") {
if (shouldProcessMentions(block.text)) { if (shouldProcessMentions(block.text)) {
@@ -2417,7 +2455,7 @@ export class Cline {
text: await parseMentions(block.text, cwd, this.urlContentFetcher), text: await parseMentions(block.text, cwd, this.urlContentFetcher),
} }
} }
return block; return block
} else if (block.type === "tool_result") { } else if (block.type === "tool_result") {
if (typeof block.content === "string") { if (typeof block.content === "string") {
if (shouldProcessMentions(block.content)) { if (shouldProcessMentions(block.content)) {
@@ -2426,7 +2464,7 @@ export class Cline {
content: await parseMentions(block.content, cwd, this.urlContentFetcher), content: await parseMentions(block.content, cwd, this.urlContentFetcher),
} }
} }
return block; return block
} else if (Array.isArray(block.content)) { } else if (Array.isArray(block.content)) {
const parsedContent = await Promise.all( const parsedContent = await Promise.all(
block.content.map(async (contentBlock) => { block.content.map(async (contentBlock) => {
@@ -2444,7 +2482,7 @@ export class Cline {
content: parsedContent, content: parsedContent,
} }
} }
return block; return block
} }
return block return block
}), }),
@@ -2570,27 +2608,30 @@ export class Cline {
// Add current time information with timezone // Add current time information with timezone
const now = new Date() const now = new Date()
const formatter = new Intl.DateTimeFormat(undefined, { const formatter = new Intl.DateTimeFormat(undefined, {
year: 'numeric', year: "numeric",
month: 'numeric', month: "numeric",
day: 'numeric', day: "numeric",
hour: 'numeric', hour: "numeric",
minute: 'numeric', minute: "numeric",
second: 'numeric', second: "numeric",
hour12: true hour12: true,
}) })
const timeZone = formatter.resolvedOptions().timeZone const timeZone = formatter.resolvedOptions().timeZone
const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation
const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? '+' : ''}${timeZoneOffset}:00` const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? "+" : ""}${timeZoneOffset}:00`
details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})` details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})`
// Add current mode and any mode-specific warnings // Add current mode and any mode-specific warnings
const { mode } = await this.providerRef.deref()?.getState() ?? {} const { mode } = (await this.providerRef.deref()?.getState()) ?? {}
const currentMode = mode ?? codeMode const currentMode = mode ?? defaultModeSlug
details += `\n\n# Current Mode\n${currentMode}` details += `\n\n# Current Mode\n${currentMode}`
// Add warning if not in code mode // Add warning if not in code mode
if (!isToolAllowedForMode('write_to_file', currentMode) || !isToolAllowedForMode('execute_command', currentMode)) { if (
details += `\n\nNOTE: You are currently in '${currentMode}' mode which only allows read-only operations. To write files or execute commands, the user will need to switch to 'code' mode. Note that only the user can switch modes.` !isToolAllowedForMode("write_to_file", currentMode) ||
!isToolAllowedForMode("execute_command", currentMode)
) {
details += `\n\nNOTE: You are currently in '${currentMode}' mode which only allows read-only operations. To write files or execute commands, the user will need to switch to '${defaultModeSlug}' mode. Note that only the user can switch modes.`
} }
if (includeFileDetails) { if (includeFileDetails) {

File diff suppressed because it is too large Load Diff

View File

@@ -1,88 +1,52 @@
import { isToolAllowedForMode, validateToolUse } from '../mode-validator' import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from "../../shared/modes"
import { codeMode, architectMode, askMode } from '../prompts/system' import { validateToolUse } from "../mode-validator"
import { CODE_ALLOWED_TOOLS, READONLY_ALLOWED_TOOLS, ToolName } from '../tool-lists'
// For testing purposes, we need to handle the 'unknown_tool' case const asTestTool = (tool: string): TestToolName => tool as TestToolName
type TestToolName = ToolName | 'unknown_tool'; const [codeMode, architectMode, askMode] = modes.map((mode) => mode.slug)
// Helper function to safely cast string to TestToolName for testing describe("mode-validator", () => {
function asTestTool(str: string): TestToolName { describe("isToolAllowedForMode", () => {
return str as TestToolName; describe("code mode", () => {
} it("allows all code mode tools", () => {
const mode = getModeConfig(codeMode)
describe('mode-validator', () => { mode.tools.forEach(([tool]) => {
describe('isToolAllowedForMode', () => {
describe('code mode', () => {
it('allows all code mode tools', () => {
CODE_ALLOWED_TOOLS.forEach(tool => {
expect(isToolAllowedForMode(tool, codeMode)).toBe(true) expect(isToolAllowedForMode(tool, codeMode)).toBe(true)
}) })
}) })
it('disallows unknown tools', () => { it("disallows unknown tools", () => {
expect(isToolAllowedForMode(asTestTool('unknown_tool'), codeMode)).toBe(false) expect(isToolAllowedForMode(asTestTool("unknown_tool"), codeMode)).toBe(false)
}) })
}) })
describe('architect mode', () => { describe("architect mode", () => {
it('allows only read-only and MCP tools', () => { it("allows configured tools", () => {
// Test allowed tools const mode = getModeConfig(architectMode)
READONLY_ALLOWED_TOOLS.forEach(tool => { mode.tools.forEach(([tool]) => {
expect(isToolAllowedForMode(tool, architectMode)).toBe(true) expect(isToolAllowedForMode(tool, architectMode)).toBe(true)
}) })
// Test specific disallowed tools that we know are in CODE_ALLOWED_TOOLS but not in READONLY_ALLOWED_TOOLS
const disallowedTools = ['execute_command', 'write_to_file', 'apply_diff'] as const;
disallowedTools.forEach(tool => {
expect(isToolAllowedForMode(tool as ToolName, architectMode)).toBe(false)
})
}) })
}) })
describe('ask mode', () => { describe("ask mode", () => {
it('allows only read-only and MCP tools', () => { it("allows configured tools", () => {
// Test allowed tools const mode = getModeConfig(askMode)
READONLY_ALLOWED_TOOLS.forEach(tool => { mode.tools.forEach(([tool]) => {
expect(isToolAllowedForMode(tool, askMode)).toBe(true) expect(isToolAllowedForMode(tool, askMode)).toBe(true)
}) })
// Test specific disallowed tools that we know are in CODE_ALLOWED_TOOLS but not in READONLY_ALLOWED_TOOLS
const disallowedTools = ['execute_command', 'write_to_file', 'apply_diff'] as const;
disallowedTools.forEach(tool => {
expect(isToolAllowedForMode(tool as ToolName, askMode)).toBe(false)
})
}) })
}) })
}) })
describe('validateToolUse', () => { describe("validateToolUse", () => {
it('throws error for disallowed tools in architect mode', () => { it("throws error for disallowed tools in architect mode", () => {
expect(() => validateToolUse('write_to_file' as ToolName, architectMode)).toThrow( expect(() => validateToolUse("unknown_tool", "architect")).toThrow(
'Tool "write_to_file" is not allowed in architect mode.' 'Tool "unknown_tool" is not allowed in architect mode.',
) )
}) })
it('throws error for disallowed tools in ask mode', () => { it("does not throw for allowed tools in architect mode", () => {
expect(() => validateToolUse('execute_command' as ToolName, askMode)).toThrow( expect(() => validateToolUse("read_file", "architect")).not.toThrow()
'Tool "execute_command" is not allowed in ask mode.'
)
})
it('throws error for unknown tools in code mode', () => {
expect(() => validateToolUse(asTestTool('unknown_tool'), codeMode)).toThrow(
'Tool "unknown_tool" is not allowed in code mode.'
)
})
it('does not throw for allowed tools', () => {
// Code mode
expect(() => validateToolUse('write_to_file' as ToolName, codeMode)).not.toThrow()
// Architect mode
expect(() => validateToolUse('read_file' as ToolName, architectMode)).not.toThrow()
// Ask mode
expect(() => validateToolUse('browser_action' as ToolName, askMode)).not.toThrow()
}) })
}) })
}) })

View File

@@ -1,7 +1,7 @@
import { ExtensionContext } from 'vscode' import { ExtensionContext } from "vscode"
import { ApiConfiguration } from '../../shared/api' import { ApiConfiguration } from "../../shared/api"
import { Mode } from '../prompts/types' import { Mode } from "../prompts/types"
import { ApiConfigMeta } from '../../shared/ExtensionMessage' import { ApiConfigMeta } from "../../shared/ExtensionMessage"
export interface ApiConfigData { export interface ApiConfigData {
currentApiConfigName: string currentApiConfigName: string
@@ -13,12 +13,12 @@ export interface ApiConfigData {
export class ConfigManager { export class ConfigManager {
private readonly defaultConfig: ApiConfigData = { private readonly defaultConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
id: this.generateId() id: this.generateId(),
} },
} },
} }
private readonly SCOPE_PREFIX = "roo_cline_config_" private readonly SCOPE_PREFIX = "roo_cline_config_"
@@ -69,7 +69,7 @@ export class ConfigManager {
const config = await this.readConfig() const config = await this.readConfig()
return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({ return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({
name, name,
id: apiConfig.id || '', id: apiConfig.id || "",
apiProvider: apiConfig.apiProvider, apiProvider: apiConfig.apiProvider,
})) }))
} catch (error) { } catch (error) {
@@ -86,7 +86,7 @@ export class ConfigManager {
const existingConfig = currentConfig.apiConfigs[name] const existingConfig = currentConfig.apiConfigs[name]
currentConfig.apiConfigs[name] = { currentConfig.apiConfigs[name] = {
...config, ...config,
id: existingConfig?.id || this.generateId() id: existingConfig?.id || this.generateId(),
} }
await this.writeConfig(currentConfig) await this.writeConfig(currentConfig)
} catch (error) { } catch (error) {
@@ -106,7 +106,7 @@ export class ConfigManager {
throw new Error(`Config '${name}' not found`) throw new Error(`Config '${name}' not found`)
} }
config.currentApiConfigName = name; config.currentApiConfigName = name
await this.writeConfig(config) await this.writeConfig(config)
return apiConfig return apiConfig

View File

@@ -1,19 +1,19 @@
import { ExtensionContext } from 'vscode' import { ExtensionContext } from "vscode"
import { ConfigManager, ApiConfigData } from '../ConfigManager' import { ConfigManager, ApiConfigData } from "../ConfigManager"
import { ApiConfiguration } from '../../../shared/api' import { ApiConfiguration } from "../../../shared/api"
// Mock VSCode ExtensionContext // Mock VSCode ExtensionContext
const mockSecrets = { const mockSecrets = {
get: jest.fn(), get: jest.fn(),
store: jest.fn(), store: jest.fn(),
delete: jest.fn() delete: jest.fn(),
} }
const mockContext = { const mockContext = {
secrets: mockSecrets secrets: mockSecrets,
} as unknown as ExtensionContext } as unknown as ExtensionContext
describe('ConfigManager', () => { describe("ConfigManager", () => {
let configManager: ConfigManager let configManager: ConfigManager
beforeEach(() => { beforeEach(() => {
@@ -21,8 +21,8 @@ describe('ConfigManager', () => {
configManager = new ConfigManager(mockContext) configManager = new ConfigManager(mockContext)
}) })
describe('initConfig', () => { describe("initConfig", () => {
it('should not write to storage when secrets.get returns null', async () => { it("should not write to storage when secrets.get returns null", async () => {
// Mock readConfig to return null // Mock readConfig to return null
mockSecrets.get.mockResolvedValueOnce(null) mockSecrets.get.mockResolvedValueOnce(null)
@@ -32,35 +32,39 @@ describe('ConfigManager', () => {
expect(mockSecrets.store).not.toHaveBeenCalled() expect(mockSecrets.store).not.toHaveBeenCalled()
}) })
it('should not initialize config if it exists', async () => { it("should not initialize config if it exists", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
config: {}, config: {},
id: 'default' id: "default",
} },
} },
})) }),
)
await configManager.initConfig() await configManager.initConfig()
expect(mockSecrets.store).not.toHaveBeenCalled() expect(mockSecrets.store).not.toHaveBeenCalled()
}) })
it('should generate IDs for configs that lack them', async () => { it("should generate IDs for configs that lack them", async () => {
// Mock a config with missing IDs // Mock a config with missing IDs
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
config: {} config: {},
}, },
test: { test: {
apiProvider: 'anthropic' apiProvider: "anthropic",
} },
} },
})) }),
)
await configManager.initConfig() await configManager.initConfig()
@@ -71,53 +75,53 @@ describe('ConfigManager', () => {
expect(storedConfig.apiConfigs.test.id).toBeTruthy() expect(storedConfig.apiConfigs.test.id).toBeTruthy()
}) })
it('should throw error if secrets storage fails', async () => { it("should throw error if secrets storage fails", async () => {
mockSecrets.get.mockRejectedValue(new Error('Storage failed')) mockSecrets.get.mockRejectedValue(new Error("Storage failed"))
await expect(configManager.initConfig()).rejects.toThrow( await expect(configManager.initConfig()).rejects.toThrow(
'Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed' "Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed",
) )
}) })
}) })
describe('ListConfig', () => { describe("ListConfig", () => {
it('should list all available configs', async () => { it("should list all available configs", async () => {
const existingConfig: ApiConfigData = { const existingConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
id: 'default' id: "default",
}, },
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
id: 'test-id' id: "test-id",
} },
}, },
modeApiConfigs: { modeApiConfigs: {
code: 'default', code: "default",
architect: 'default', architect: "default",
ask: 'default' ask: "default",
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
const configs = await configManager.ListConfig() const configs = await configManager.ListConfig()
expect(configs).toEqual([ expect(configs).toEqual([
{ name: 'default', id: 'default', apiProvider: undefined }, { name: "default", id: "default", apiProvider: undefined },
{ name: 'test', id: 'test-id', apiProvider: 'anthropic' } { name: "test", id: "test-id", apiProvider: "anthropic" },
]) ])
}) })
it('should handle empty config file', async () => { it("should handle empty config file", async () => {
const emptyConfig: ApiConfigData = { const emptyConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: {}, apiConfigs: {},
modeApiConfigs: { modeApiConfigs: {
code: 'default', code: "default",
architect: 'default', architect: "default",
ask: 'default' ask: "default",
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig))
@@ -126,326 +130,340 @@ describe('ConfigManager', () => {
expect(configs).toEqual([]) expect(configs).toEqual([])
}) })
it('should throw error if reading from secrets fails', async () => { it("should throw error if reading from secrets fails", async () => {
mockSecrets.get.mockRejectedValue(new Error('Read failed')) mockSecrets.get.mockRejectedValue(new Error("Read failed"))
await expect(configManager.ListConfig()).rejects.toThrow( await expect(configManager.ListConfig()).rejects.toThrow(
'Failed to list configs: Error: Failed to read config from secrets: Error: Read failed' "Failed to list configs: Error: Failed to read config from secrets: Error: Read failed",
) )
}) })
}) })
describe('SaveConfig', () => { describe("SaveConfig", () => {
it('should save new config', async () => { it("should save new config", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: {} default: {},
}, },
modeApiConfigs: { modeApiConfigs: {
code: 'default', code: "default",
architect: 'default', architect: "default",
ask: 'default' ask: "default",
} },
})) }),
)
const newConfig: ApiConfiguration = { const newConfig: ApiConfiguration = {
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'test-key' apiKey: "test-key",
} }
await configManager.SaveConfig('test', newConfig) await configManager.SaveConfig("test", newConfig)
// Get the actual stored config to check the generated ID // Get the actual stored config to check the generated ID
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
const testConfigId = storedConfig.apiConfigs.test.id const testConfigId = storedConfig.apiConfigs.test.id
const expectedConfig = { const expectedConfig = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: {}, default: {},
test: { test: {
...newConfig, ...newConfig,
id: testConfigId id: testConfigId,
} },
}, },
modeApiConfigs: { modeApiConfigs: {
code: 'default', code: "default",
architect: 'default', architect: "default",
ask: 'default' ask: "default",
} },
} }
expect(mockSecrets.store).toHaveBeenCalledWith( expect(mockSecrets.store).toHaveBeenCalledWith(
'roo_cline_config_api_config', "roo_cline_config_api_config",
JSON.stringify(expectedConfig, null, 2) JSON.stringify(expectedConfig, null, 2),
) )
}) })
it('should update existing config', async () => { it("should update existing config", async () => {
const existingConfig: ApiConfigData = { const existingConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'old-key', apiKey: "old-key",
id: 'test-id' id: "test-id",
} },
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
const updatedConfig: ApiConfiguration = { const updatedConfig: ApiConfiguration = {
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'new-key' apiKey: "new-key",
} }
await configManager.SaveConfig('test', updatedConfig) await configManager.SaveConfig("test", updatedConfig)
const expectedConfig = { const expectedConfig = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'new-key', apiKey: "new-key",
id: 'test-id' id: "test-id",
} },
} },
} }
expect(mockSecrets.store).toHaveBeenCalledWith( expect(mockSecrets.store).toHaveBeenCalledWith(
'roo_cline_config_api_config', "roo_cline_config_api_config",
JSON.stringify(expectedConfig, null, 2) JSON.stringify(expectedConfig, null, 2),
) )
}) })
it('should throw error if secrets storage fails', async () => { it("should throw error if secrets storage fails", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
apiConfigs: { default: {} } currentApiConfigName: "default",
})) apiConfigs: { default: {} },
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed')) }),
)
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
await expect(configManager.SaveConfig('test', {})).rejects.toThrow( await expect(configManager.SaveConfig("test", {})).rejects.toThrow(
'Failed to save config: Error: Failed to write config to secrets: Error: Storage failed' "Failed to save config: Error: Failed to write config to secrets: Error: Storage failed",
) )
}) })
}) })
describe('DeleteConfig', () => { describe("DeleteConfig", () => {
it('should delete existing config', async () => { it("should delete existing config", async () => {
const existingConfig: ApiConfigData = { const existingConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
id: 'default' id: "default",
}, },
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
id: 'test-id' id: "test-id",
} },
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
await configManager.DeleteConfig('test') await configManager.DeleteConfig("test")
// Get the stored config to check the ID // Get the stored config to check the ID
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
expect(storedConfig.currentApiConfigName).toBe('default') expect(storedConfig.currentApiConfigName).toBe("default")
expect(Object.keys(storedConfig.apiConfigs)).toEqual(['default']) expect(Object.keys(storedConfig.apiConfigs)).toEqual(["default"])
expect(storedConfig.apiConfigs.default.id).toBeTruthy() expect(storedConfig.apiConfigs.default.id).toBeTruthy()
}) })
it('should throw error when trying to delete non-existent config', async () => { it("should throw error when trying to delete non-existent config", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
apiConfigs: { default: {} } currentApiConfigName: "default",
})) apiConfigs: { default: {} },
}),
await expect(configManager.DeleteConfig('nonexistent')).rejects.toThrow(
"Config 'nonexistent' not found"
) )
await expect(configManager.DeleteConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found")
}) })
it('should throw error when trying to delete last remaining config', async () => { it("should throw error when trying to delete last remaining config", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
id: 'default' id: "default",
} },
} },
})) }),
)
await expect(configManager.DeleteConfig('default')).rejects.toThrow( await expect(configManager.DeleteConfig("default")).rejects.toThrow(
'Cannot delete the last remaining configuration.' "Cannot delete the last remaining configuration.",
) )
}) })
}) })
describe('LoadConfig', () => { describe("LoadConfig", () => {
it('should load config and update current config name', async () => { it("should load config and update current config name", async () => {
const existingConfig: ApiConfigData = { const existingConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'test-key', apiKey: "test-key",
id: 'test-id' id: "test-id",
} },
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
const config = await configManager.LoadConfig('test') const config = await configManager.LoadConfig("test")
expect(config).toEqual({ expect(config).toEqual({
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'test-key', apiKey: "test-key",
id: 'test-id' id: "test-id",
}) })
// Get the stored config to check the structure // Get the stored config to check the structure
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
expect(storedConfig.currentApiConfigName).toBe('test') expect(storedConfig.currentApiConfigName).toBe("test")
expect(storedConfig.apiConfigs.test).toEqual({ expect(storedConfig.apiConfigs.test).toEqual({
apiProvider: 'anthropic', apiProvider: "anthropic",
apiKey: 'test-key', apiKey: "test-key",
id: 'test-id' id: "test-id",
}) })
}) })
it('should throw error when config does not exist', async () => { it("should throw error when config does not exist", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
config: {}, config: {},
id: 'default' id: "default",
} },
} },
})) }),
await expect(configManager.LoadConfig('nonexistent')).rejects.toThrow(
"Config 'nonexistent' not found"
) )
await expect(configManager.LoadConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found")
}) })
it('should throw error if secrets storage fails', async () => { it("should throw error if secrets storage fails", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
test: { test: {
config: { config: {
apiProvider: 'anthropic' apiProvider: "anthropic",
}, },
id: 'test-id' id: "test-id",
} },
} },
})) }),
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed')) )
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
await expect(configManager.LoadConfig('test')).rejects.toThrow( await expect(configManager.LoadConfig("test")).rejects.toThrow(
'Failed to load config: Error: Failed to write config to secrets: Error: Storage failed' "Failed to load config: Error: Failed to write config to secrets: Error: Storage failed",
) )
}) })
}) })
describe('SetCurrentConfig', () => { describe("SetCurrentConfig", () => {
it('should set current config', async () => { it("should set current config", async () => {
const existingConfig: ApiConfigData = { const existingConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
id: 'default' id: "default",
}, },
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
id: 'test-id' id: "test-id",
} },
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
await configManager.SetCurrentConfig('test') await configManager.SetCurrentConfig("test")
// Get the stored config to check the structure // Get the stored config to check the structure
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1]) const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
expect(storedConfig.currentApiConfigName).toBe('test') expect(storedConfig.currentApiConfigName).toBe("test")
expect(storedConfig.apiConfigs.default.id).toBe('default') expect(storedConfig.apiConfigs.default.id).toBe("default")
expect(storedConfig.apiConfigs.test).toEqual({ expect(storedConfig.apiConfigs.test).toEqual({
apiProvider: 'anthropic', apiProvider: "anthropic",
id: 'test-id' id: "test-id",
}) })
}) })
it('should throw error when config does not exist', async () => { it("should throw error when config does not exist", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
apiConfigs: { default: {} } currentApiConfigName: "default",
})) apiConfigs: { default: {} },
}),
)
await expect(configManager.SetCurrentConfig('nonexistent')).rejects.toThrow( await expect(configManager.SetCurrentConfig("nonexistent")).rejects.toThrow(
"Config 'nonexistent' not found" "Config 'nonexistent' not found",
) )
}) })
it('should throw error if secrets storage fails', async () => { it("should throw error if secrets storage fails", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
test: { apiProvider: 'anthropic' } test: { apiProvider: "anthropic" },
} },
})) }),
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed')) )
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
await expect(configManager.SetCurrentConfig('test')).rejects.toThrow( await expect(configManager.SetCurrentConfig("test")).rejects.toThrow(
'Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed' "Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed",
) )
}) })
}) })
describe('HasConfig', () => { describe("HasConfig", () => {
it('should return true for existing config', async () => { it("should return true for existing config", async () => {
const existingConfig: ApiConfigData = { const existingConfig: ApiConfigData = {
currentApiConfigName: 'default', currentApiConfigName: "default",
apiConfigs: { apiConfigs: {
default: { default: {
id: 'default' id: "default",
}, },
test: { test: {
apiProvider: 'anthropic', apiProvider: "anthropic",
id: 'test-id' id: "test-id",
} },
} },
} }
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig)) mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
const hasConfig = await configManager.HasConfig('test') const hasConfig = await configManager.HasConfig("test")
expect(hasConfig).toBe(true) expect(hasConfig).toBe(true)
}) })
it('should return false for non-existent config', async () => { it("should return false for non-existent config", async () => {
mockSecrets.get.mockResolvedValue(JSON.stringify({ mockSecrets.get.mockResolvedValue(
currentApiConfigName: 'default', JSON.stringify({
apiConfigs: { default: {} } currentApiConfigName: "default",
})) apiConfigs: { default: {} },
}),
)
const hasConfig = await configManager.HasConfig('nonexistent') const hasConfig = await configManager.HasConfig("nonexistent")
expect(hasConfig).toBe(false) expect(hasConfig).toBe(false)
}) })
it('should throw error if secrets storage fails', async () => { it("should throw error if secrets storage fails", async () => {
mockSecrets.get.mockRejectedValue(new Error('Storage failed')) mockSecrets.get.mockRejectedValue(new Error("Storage failed"))
await expect(configManager.HasConfig('test')).rejects.toThrow( await expect(configManager.HasConfig("test")).rejects.toThrow(
'Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed' "Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed",
) )
}) })
}) })

View File

@@ -1,13 +1,17 @@
import type { DiffStrategy } from './types' import type { DiffStrategy } from "./types"
import { UnifiedDiffStrategy } from './strategies/unified' import { UnifiedDiffStrategy } from "./strategies/unified"
import { SearchReplaceDiffStrategy } from './strategies/search-replace' import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
import { NewUnifiedDiffStrategy } from './strategies/new-unified' import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
/** /**
* Get the appropriate diff strategy for the given model * Get the appropriate diff strategy for the given model
* @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus') * @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus')
* @returns The appropriate diff strategy for the model * @returns The appropriate diff strategy for the model
*/ */
export function getDiffStrategy(model: string, fuzzyMatchThreshold?: number, experimentalDiffStrategy: boolean = false): DiffStrategy { export function getDiffStrategy(
model: string,
fuzzyMatchThreshold?: number,
experimentalDiffStrategy: boolean = false,
): DiffStrategy {
if (experimentalDiffStrategy) { if (experimentalDiffStrategy) {
return new NewUnifiedDiffStrategy(fuzzyMatchThreshold) return new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
} }

View File

@@ -1,46 +1,45 @@
import { NewUnifiedDiffStrategy } from '../new-unified'; import { NewUnifiedDiffStrategy } from "../new-unified"
describe('main', () => {
describe("main", () => {
let strategy: NewUnifiedDiffStrategy let strategy: NewUnifiedDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new NewUnifiedDiffStrategy(0.97) strategy = new NewUnifiedDiffStrategy(0.97)
}) })
describe('constructor', () => { describe("constructor", () => {
it('should use default confidence threshold when not provided', () => { it("should use default confidence threshold when not provided", () => {
const defaultStrategy = new NewUnifiedDiffStrategy() const defaultStrategy = new NewUnifiedDiffStrategy()
expect(defaultStrategy['confidenceThreshold']).toBe(1) expect(defaultStrategy["confidenceThreshold"]).toBe(1)
}) })
it('should use provided confidence threshold', () => { it("should use provided confidence threshold", () => {
const customStrategy = new NewUnifiedDiffStrategy(0.85) const customStrategy = new NewUnifiedDiffStrategy(0.85)
expect(customStrategy['confidenceThreshold']).toBe(0.85) expect(customStrategy["confidenceThreshold"]).toBe(0.85)
}) })
it('should enforce minimum confidence threshold', () => { it("should enforce minimum confidence threshold", () => {
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8 const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
expect(lowStrategy['confidenceThreshold']).toBe(0.8) expect(lowStrategy["confidenceThreshold"]).toBe(0.8)
}) })
}) })
describe('getToolDescription', () => { describe("getToolDescription", () => {
it('should return tool description with correct cwd', () => { it("should return tool description with correct cwd", () => {
const cwd = '/test/path' const cwd = "/test/path"
const description = strategy.getToolDescription(cwd) const description = strategy.getToolDescription({ cwd })
expect(description).toContain('apply_diff') expect(description).toContain("apply_diff")
expect(description).toContain(cwd) expect(description).toContain(cwd)
expect(description).toContain('Parameters:') expect(description).toContain("Parameters:")
expect(description).toContain('Format Requirements:') expect(description).toContain("Format Requirements:")
}) })
}) })
it('should apply simple diff correctly', async () => { it("should apply simple diff correctly", async () => {
const original = `line1 const original = `line1
line2 line2
line3`; line3`
const diff = `--- a/file.txt const diff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ -49,24 +48,24 @@ line3`;
+new line +new line
line2 line2
-line3 -line3
+modified line3`; +modified line3`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if(result.success) { if (result.success) {
expect(result.content).toBe(`line1 expect(result.content).toBe(`line1
new line new line
line2 line2
modified line3`); modified line3`)
} }
}); })
it('should handle multiple hunks', async () => { it("should handle multiple hunks", async () => {
const original = `line1 const original = `line1
line2 line2
line3 line3
line4 line4
line5`; line5`
const diff = `--- a/file.txt const diff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ -80,10 +79,10 @@ line5`;
line4 line4
-line5 -line5
+modified line5 +modified line5
+new line at end`; +new line at end`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`line1 expect(result.content).toBe(`line1
new line new line
@@ -91,11 +90,11 @@ line2
modified line3 modified line3
line4 line4
modified line5 modified line5
new line at end`); new line at end`)
} }
}); })
it('should handle complex large', async () => { it("should handle complex large", async () => {
const original = `line1 const original = `line1
line2 line2
line3 line3
@@ -105,7 +104,7 @@ line6
line7 line7
line8 line8
line9 line9
line10`; line10`
const diff = `--- a/file.txt const diff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ -130,10 +129,10 @@ line10`;
line9 line9
-line10 -line10
+final line +final line
+very last line`; +very last line`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`line1 expect(result.content).toBe(`line1
header line header line
@@ -150,11 +149,11 @@ changed line8
bonus line bonus line
line9 line9
final line final line
very last line`); very last line`)
} }
}); })
it('should handle indentation changes', async () => { it("should handle indentation changes", async () => {
const original = `first line const original = `first line
indented line indented line
double indented line double indented line
@@ -164,7 +163,7 @@ no indent
double indent again double indent again
triple indent triple indent
back to single back to single
last line`; last line`
const diff = `--- original const diff = `--- original
+++ modified +++ modified
@@ -181,7 +180,7 @@ last line`;
- triple indent - triple indent
+ hi there mate + hi there mate
back to single back to single
last line`; last line`
const expected = `first line const expected = `first line
indented line indented line
@@ -194,17 +193,16 @@ no indent
double indent again double indent again
hi there mate hi there mate
back to single back to single
last line`; last line`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(expected); expect(result.content).toBe(expected)
} }
}); })
it('should handle high level edits', async () => {
it("should handle high level edits", async () => {
const original = `def factorial(n): const original = `def factorial(n):
if n == 0: if n == 0:
return 1 return 1
@@ -222,20 +220,20 @@ last line`;
+ else: + else:
+ return number * factorial(number-1)` + return number * factorial(number-1)`
const expected = `def factorial(number): const expected = `def factorial(number):
if number == 0: if number == 0:
return 1 return 1
else: else:
return number * factorial(number-1)` return number * factorial(number-1)`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(expected); expect(result.content).toBe(expected)
} }
}); })
it('it should handle very complex edits', async () => { it("it should handle very complex edits", async () => {
const original = `//Initialize the array that will hold the primes const original = `//Initialize the array that will hold the primes
var primeArray = []; var primeArray = [];
/*Write a function that checks for primeness and /*Write a function that checks for primeness and
@@ -321,56 +319,55 @@ for (var i = 2; primeArray.length < numPrimes; i++) {
console.log(primeArray); console.log(primeArray);
` `
const result = await strategy.applyDiff(original, diff)
const result = await strategy.applyDiff(original, diff); expect(result.success).toBe(true)
expect(result.success).toBe(true);
if (result.success) { if (result.success) {
expect(result.content).toBe(expected); expect(result.content).toBe(expected)
} }
}); })
describe('error handling and edge cases', () => { describe("error handling and edge cases", () => {
it('should reject completely invalid diff format', async () => { it("should reject completely invalid diff format", async () => {
const original = 'line1\nline2\nline3'; const original = "line1\nline2\nline3"
const invalidDiff = 'this is not a diff at all'; const invalidDiff = "this is not a diff at all"
const result = await strategy.applyDiff(original, invalidDiff); const result = await strategy.applyDiff(original, invalidDiff)
expect(result.success).toBe(false); expect(result.success).toBe(false)
}); })
it('should reject diff with invalid hunk format', async () => { it("should reject diff with invalid hunk format", async () => {
const original = 'line1\nline2\nline3'; const original = "line1\nline2\nline3"
const invalidHunkDiff = `--- a/file.txt const invalidHunkDiff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
invalid hunk header invalid hunk header
line1 line1
-line2 -line2
+new line`; +new line`
const result = await strategy.applyDiff(original, invalidHunkDiff); const result = await strategy.applyDiff(original, invalidHunkDiff)
expect(result.success).toBe(false); expect(result.success).toBe(false)
}); })
it('should fail when diff tries to modify non-existent content', async () => { it("should fail when diff tries to modify non-existent content", async () => {
const original = 'line1\nline2\nline3'; const original = "line1\nline2\nline3"
const nonMatchingDiff = `--- a/file.txt const nonMatchingDiff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ ... @@ @@ ... @@
line1 line1
-nonexistent line -nonexistent line
+new line +new line
line3`; line3`
const result = await strategy.applyDiff(original, nonMatchingDiff); const result = await strategy.applyDiff(original, nonMatchingDiff)
expect(result.success).toBe(false); expect(result.success).toBe(false)
}); })
it('should handle overlapping hunks', async () => { it("should handle overlapping hunks", async () => {
const original = `line1 const original = `line1
line2 line2
line3 line3
line4 line4
line5`; line5`
const overlappingDiff = `--- a/file.txt const overlappingDiff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ ... @@ @@ ... @@
@@ -384,18 +381,18 @@ line5`;
-line3 -line3
-line4 -line4
+modified3and4 +modified3and4
line5`; line5`
const result = await strategy.applyDiff(original, overlappingDiff); const result = await strategy.applyDiff(original, overlappingDiff)
expect(result.success).toBe(false); expect(result.success).toBe(false)
}); })
it('should handle empty lines modifications', async () => { it("should handle empty lines modifications", async () => {
const original = `line1 const original = `line1
line3 line3
line5`; line5`
const emptyLinesDiff = `--- a/file.txt const emptyLinesDiff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ ... @@ @@ ... @@
@@ -404,53 +401,53 @@ line5`;
-line3 -line3
+line3modified +line3modified
line5`; line5`
const result = await strategy.applyDiff(original, emptyLinesDiff); const result = await strategy.applyDiff(original, emptyLinesDiff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`line1 expect(result.content).toBe(`line1
line3modified line3modified
line5`); line5`)
} }
}); })
it('should handle mixed line endings in diff', async () => { it("should handle mixed line endings in diff", async () => {
const original = 'line1\r\nline2\nline3\r\n'; const original = "line1\r\nline2\nline3\r\n"
const mixedEndingsDiff = `--- a/file.txt const mixedEndingsDiff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ ... @@ @@ ... @@
line1\r line1\r
-line2 -line2
+modified2\r +modified2\r
line3`; line3`
const result = await strategy.applyDiff(original, mixedEndingsDiff); const result = await strategy.applyDiff(original, mixedEndingsDiff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('line1\r\nmodified2\r\nline3\r\n'); expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
} }
}); })
it('should handle partial line modifications', async () => { it("should handle partial line modifications", async () => {
const original = 'const value = oldValue + 123;'; const original = "const value = oldValue + 123;"
const partialDiff = `--- a/file.txt const partialDiff = `--- a/file.txt
+++ b/file.txt +++ b/file.txt
@@ ... @@ @@ ... @@
-const value = oldValue + 123; -const value = oldValue + 123;
+const value = newValue + 123;`; +const value = newValue + 123;`
const result = await strategy.applyDiff(original, partialDiff); const result = await strategy.applyDiff(original, partialDiff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('const value = newValue + 123;'); expect(result.content).toBe("const value = newValue + 123;")
} }
}); })
it('should handle slightly malformed but recoverable diff', async () => { it("should handle slightly malformed but recoverable diff", async () => {
const original = 'line1\nline2\nline3'; const original = "line1\nline2\nline3"
// Missing space after --- and +++ // Missing space after --- and +++
const slightlyBadDiff = `---a/file.txt const slightlyBadDiff = `---a/file.txt
+++b/file.txt +++b/file.txt
@@ -458,18 +455,18 @@ line5`);
line1 line1
-line2 -line2
+new line +new line
line3`; line3`
const result = await strategy.applyDiff(original, slightlyBadDiff); const result = await strategy.applyDiff(original, slightlyBadDiff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('line1\nnew line\nline3'); expect(result.content).toBe("line1\nnew line\nline3")
} }
}); })
}); })
describe('similar code sections', () => { describe("similar code sections", () => {
it('should correctly modify the right section when similar code exists', async () => { it("should correctly modify the right section when similar code exists", async () => {
const original = `function add(a, b) { const original = `function add(a, b) {
return a + b; return a + b;
} }
@@ -480,7 +477,7 @@ function subtract(a, b) {
function multiply(a, b) { function multiply(a, b) {
return a + b; // Bug here return a + b; // Bug here
}`; }`
const diff = `--- a/math.js const diff = `--- a/math.js
+++ b/math.js +++ b/math.js
@@ -488,10 +485,10 @@ function multiply(a, b) {
function multiply(a, b) { function multiply(a, b) {
- return a + b; // Bug here - return a + b; // Bug here
+ return a * b; + return a * b;
}`; }`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`function add(a, b) { expect(result.content).toBe(`function add(a, b) {
return a + b; return a + b;
@@ -503,11 +500,11 @@ function subtract(a, b) {
function multiply(a, b) { function multiply(a, b) {
return a * b; return a * b;
}`); }`)
} }
}); })
it('should handle multiple similar sections with correct context', async () => { it("should handle multiple similar sections with correct context", async () => {
const original = `if (condition) { const original = `if (condition) {
doSomething(); doSomething();
doSomething(); doSomething();
@@ -518,7 +515,7 @@ if (otherCondition) {
doSomething(); doSomething();
doSomething(); doSomething();
doSomething(); doSomething();
}`; }`
const diff = `--- a/file.js const diff = `--- a/file.js
+++ b/file.js +++ b/file.js
@@ -528,10 +525,10 @@ if (otherCondition) {
- doSomething(); - doSomething();
+ doSomethingElse(); + doSomethingElse();
doSomething(); doSomething();
}`; }`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`if (condition) { expect(result.content).toBe(`if (condition) {
doSomething(); doSomething();
@@ -543,13 +540,13 @@ if (otherCondition) {
doSomething(); doSomething();
doSomethingElse(); doSomethingElse();
doSomething(); doSomething();
}`); }`)
} }
}); })
}); })
describe('hunk splitting', () => { describe("hunk splitting", () => {
it('should handle large diffs with multiple non-contiguous changes', async () => { it("should handle large diffs with multiple non-contiguous changes", async () => {
const original = `import { readFile } from 'fs'; const original = `import { readFile } from 'fs';
import { join } from 'path'; import { join } from 'path';
import { Logger } from './logger'; import { Logger } from './logger';
@@ -595,7 +592,7 @@ export {
validateInput, validateInput,
writeOutput, writeOutput,
parseConfig parseConfig
};`; };`
const diff = `--- a/file.ts const diff = `--- a/file.ts
+++ b/file.ts +++ b/file.ts
@@ -672,7 +669,7 @@ export {
- parseConfig - parseConfig
+ parseConfig, + parseConfig,
+ type Config + type Config
};`; };`
const expected = `import { readFile, writeFile } from 'fs'; const expected = `import { readFile, writeFile } from 'fs';
import { join } from 'path'; import { join } from 'path';
@@ -727,13 +724,13 @@ export {
writeOutput, writeOutput,
parseConfig, parseConfig,
type Config type Config
};`; };`
const result = await strategy.applyDiff(original, diff); const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true); expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(expected); expect(result.content).toBe(expected)
} }
}); })
}); })
}); })

View File

@@ -1,14 +1,14 @@
import { SearchReplaceDiffStrategy } from '../search-replace' import { SearchReplaceDiffStrategy } from "../search-replace"
describe('SearchReplaceDiffStrategy', () => { describe("SearchReplaceDiffStrategy", () => {
describe('exact matching', () => { describe("exact matching", () => {
let strategy: SearchReplaceDiffStrategy let strategy: SearchReplaceDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new SearchReplaceDiffStrategy(1.0, 5) // Default 1.0 threshold for exact matching, 5 line buffer for tests strategy = new SearchReplaceDiffStrategy(1.0, 5) // Default 1.0 threshold for exact matching, 5 line buffer for tests
}) })
it('should replace matching content', async () => { it("should replace matching content", async () => {
const originalContent = 'function hello() {\n console.log("hello")\n}\n' const originalContent = 'function hello() {\n console.log("hello")\n}\n'
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -28,8 +28,8 @@ function hello() {
} }
}) })
it('should match content with different surrounding whitespace', async () => { it("should match content with different surrounding whitespace", async () => {
const originalContent = '\nfunction example() {\n return 42;\n}\n\n' const originalContent = "\nfunction example() {\n return 42;\n}\n\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
function example() { function example() {
@@ -44,12 +44,12 @@ function example() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('\nfunction example() {\n return 43;\n}\n\n') expect(result.content).toBe("\nfunction example() {\n return 43;\n}\n\n")
} }
}) })
it('should match content with different indentation in search block', async () => { it("should match content with different indentation in search block", async () => {
const originalContent = ' function test() {\n return true;\n }\n' const originalContent = " function test() {\n return true;\n }\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
function test() { function test() {
@@ -64,11 +64,11 @@ function test() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(' function test() {\n return false;\n }\n') expect(result.content).toBe(" function test() {\n return false;\n }\n")
} }
}) })
it('should handle tab-based indentation', async () => { it("should handle tab-based indentation", async () => {
const originalContent = "function test() {\n\treturn true;\n}\n" const originalContent = "function test() {\n\treturn true;\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -88,7 +88,7 @@ function test() {
} }
}) })
it('should preserve mixed tabs and spaces', async () => { it("should preserve mixed tabs and spaces", async () => {
const originalContent = "\tclass Example {\n\t constructor() {\n\t\tthis.value = 0;\n\t }\n\t}" const originalContent = "\tclass Example {\n\t constructor() {\n\t\tthis.value = 0;\n\t }\n\t}"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -108,11 +108,13 @@ function test() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe("\tclass Example {\n\t constructor() {\n\t\tthis.value = 1;\n\t }\n\t}") expect(result.content).toBe(
"\tclass Example {\n\t constructor() {\n\t\tthis.value = 1;\n\t }\n\t}",
)
} }
}) })
it('should handle additional indentation with tabs', async () => { it("should handle additional indentation with tabs", async () => {
const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}" const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -133,7 +135,7 @@ function test() {
} }
}) })
it('should preserve exact indentation characters when adding lines', async () => { it("should preserve exact indentation characters when adding lines", async () => {
const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}" const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -151,11 +153,13 @@ function test() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe("\tfunction test() {\n\t\t// First comment\n\t\t// Second comment\n\t\treturn true;\n\t}") expect(result.content).toBe(
"\tfunction test() {\n\t\t// First comment\n\t\t// Second comment\n\t\treturn true;\n\t}",
)
} }
}) })
it('should handle Windows-style CRLF line endings', async () => { it("should handle Windows-style CRLF line endings", async () => {
const originalContent = "function test() {\r\n return true;\r\n}\r\n" const originalContent = "function test() {\r\n return true;\r\n}\r\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -175,7 +179,7 @@ function test() {
} }
}) })
it('should return false if search content does not match', async () => { it("should return false if search content does not match", async () => {
const originalContent = 'function hello() {\n console.log("hello")\n}\n' const originalContent = 'function hello() {\n console.log("hello")\n}\n'
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -192,7 +196,7 @@ function hello() {
expect(result.success).toBe(false) expect(result.success).toBe(false)
}) })
it('should return false if diff format is invalid', async () => { it("should return false if diff format is invalid", async () => {
const originalContent = 'function hello() {\n console.log("hello")\n}\n' const originalContent = 'function hello() {\n console.log("hello")\n}\n'
const diffContent = `test.ts\nInvalid diff format` const diffContent = `test.ts\nInvalid diff format`
@@ -200,8 +204,9 @@ function hello() {
expect(result.success).toBe(false) expect(result.success).toBe(false)
}) })
it('should handle multiple lines with proper indentation', async () => { it("should handle multiple lines with proper indentation", async () => {
const originalContent = 'class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n return this.value\n }\n}\n' const originalContent =
"class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n return this.value\n }\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
getValue() { getValue() {
@@ -218,11 +223,13 @@ function hello() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n // Add logging\n console.log("Getting value")\n return this.value\n }\n}\n') expect(result.content).toBe(
'class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n // Add logging\n console.log("Getting value")\n return this.value\n }\n}\n',
)
} }
}) })
it('should preserve whitespace exactly in the output', async () => { it("should preserve whitespace exactly in the output", async () => {
const originalContent = " indented\n more indented\n back\n" const originalContent = " indented\n more indented\n back\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -242,8 +249,8 @@ function hello() {
} }
}) })
it('should preserve indentation when adding new lines after existing content', async () => { it("should preserve indentation when adding new lines after existing content", async () => {
const originalContent = ' onScroll={() => updateHighlights()}' const originalContent = " onScroll={() => updateHighlights()}"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
onScroll={() => updateHighlights()} onScroll={() => updateHighlights()}
@@ -258,11 +265,13 @@ function hello() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(' onScroll={() => updateHighlights()}\n onDragOver={(e) => {\n e.preventDefault()\n e.stopPropagation()\n }}') expect(result.content).toBe(
" onScroll={() => updateHighlights()}\n onDragOver={(e) => {\n e.preventDefault()\n e.stopPropagation()\n }}",
)
} }
}) })
it('should handle varying indentation levels correctly', async () => { it("should handle varying indentation levels correctly", async () => {
const originalContent = ` const originalContent = `
class Example { class Example {
constructor() { constructor() {
@@ -271,7 +280,7 @@ class Example {
this.init(); this.init();
} }
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
@@ -294,12 +303,13 @@ class Example {
} }
} }
} }
>>>>>>> REPLACE`.trim(); >>>>>>> REPLACE`.trim()
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(` expect(result.content).toBe(
`
class Example { class Example {
constructor() { constructor() {
this.value = 1; this.value = 1;
@@ -309,11 +319,12 @@ class Example {
this.validate(); this.validate();
} }
} }
}`.trim()); }`.trim(),
)
} }
}) })
it('should handle mixed indentation styles in the same file', async () => { it("should handle mixed indentation styles in the same file", async () => {
const originalContent = `class Example { const originalContent = `class Example {
constructor() { constructor() {
this.value = 0; this.value = 0;
@@ -321,7 +332,7 @@ class Example {
this.init(); this.init();
} }
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
constructor() { constructor() {
@@ -338,9 +349,9 @@ class Example {
this.validate(); this.validate();
} }
} }
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`class Example { expect(result.content).toBe(`class Example {
@@ -351,17 +362,17 @@ class Example {
this.validate(); this.validate();
} }
} }
}`); }`)
} }
}) })
it('should handle Python-style significant whitespace', async () => { it("should handle Python-style significant whitespace", async () => {
const originalContent = `def example(): const originalContent = `def example():
if condition: if condition:
do_something() do_something()
for item in items: for item in items:
process(item) process(item)
return True`.trim(); return True`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
if condition: if condition:
@@ -374,9 +385,9 @@ class Example {
while items: while items:
item = items.pop() item = items.pop()
process(item) process(item)
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`def example(): expect(result.content).toBe(`def example():
@@ -385,18 +396,18 @@ class Example {
while items: while items:
item = items.pop() item = items.pop()
process(item) process(item)
return True`); return True`)
} }
}); })
it('should preserve empty lines with indentation', async () => { it("should preserve empty lines with indentation", async () => {
const originalContent = `function test() { const originalContent = `function test() {
const x = 1; const x = 1;
if (x) { if (x) {
return true; return true;
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
const x = 1; const x = 1;
@@ -407,9 +418,9 @@ class Example {
// Check x // Check x
if (x) { if (x) {
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`function test() { expect(result.content).toBe(`function test() {
@@ -419,18 +430,18 @@ class Example {
if (x) { if (x) {
return true; return true;
} }
}`); }`)
} }
}); })
it('should handle indentation when replacing entire blocks', async () => { it("should handle indentation when replacing entire blocks", async () => {
const originalContent = `class Test { const originalContent = `class Test {
method() { method() {
if (true) { if (true) {
console.log("test"); console.log("test");
} }
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
method() { method() {
@@ -448,9 +459,9 @@ class Example {
console.error(e); console.error(e);
} }
} }
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`class Test { expect(result.content).toBe(`class Test {
@@ -463,11 +474,11 @@ class Example {
console.error(e); console.error(e);
} }
} }
}`); }`)
} }
}); })
it('should handle negative indentation relative to search content', async () => { it("should handle negative indentation relative to search content", async () => {
const originalContent = `class Example { const originalContent = `class Example {
constructor() { constructor() {
if (true) { if (true) {
@@ -475,7 +486,7 @@ class Example {
this.setup(); this.setup();
} }
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
this.init(); this.init();
@@ -483,9 +494,9 @@ class Example {
======= =======
this.init(); this.init();
this.setup(); this.setup();
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`class Example { expect(result.content).toBe(`class Example {
@@ -495,26 +506,26 @@ class Example {
this.setup(); this.setup();
} }
} }
}`); }`)
} }
}); })
it('should handle extreme negative indentation (no indent)', async () => { it("should handle extreme negative indentation (no indent)", async () => {
const originalContent = `class Example { const originalContent = `class Example {
constructor() { constructor() {
if (true) { if (true) {
this.init(); this.init();
} }
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
this.init(); this.init();
======= =======
this.init(); this.init();
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`class Example { expect(result.content).toBe(`class Example {
@@ -523,11 +534,11 @@ this.init();
this.init(); this.init();
} }
} }
}`); }`)
} }
}); })
it('should handle mixed indentation changes in replace block', async () => { it("should handle mixed indentation changes in replace block", async () => {
const originalContent = `class Example { const originalContent = `class Example {
constructor() { constructor() {
if (true) { if (true) {
@@ -536,7 +547,7 @@ this.init();
this.validate(); this.validate();
} }
} }
}`.trim(); }`.trim()
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
this.init(); this.init();
@@ -546,9 +557,9 @@ this.init();
this.init(); this.init();
this.setup(); this.setup();
this.validate(); this.validate();
>>>>>>> REPLACE`; >>>>>>> REPLACE`
const result = await strategy.applyDiff(originalContent, diffContent); const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(`class Example { expect(result.content).toBe(`class Example {
@@ -559,11 +570,11 @@ this.init();
this.validate(); this.validate();
} }
} }
}`); }`)
} }
}); })
it('should find matches from middle out', async () => { it("should find matches from middle out", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return "target"; return "target";
@@ -621,16 +632,16 @@ function five() {
}) })
}) })
describe('line number stripping', () => { describe("line number stripping", () => {
describe('line number stripping', () => { describe("line number stripping", () => {
let strategy: SearchReplaceDiffStrategy let strategy: SearchReplaceDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new SearchReplaceDiffStrategy() strategy = new SearchReplaceDiffStrategy()
}) })
it('should strip line numbers from both search and replace sections', async () => { it("should strip line numbers from both search and replace sections", async () => {
const originalContent = 'function test() {\n return true;\n}\n' const originalContent = "function test() {\n return true;\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
1 | function test() { 1 | function test() {
@@ -645,12 +656,12 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('function test() {\n return false;\n}\n') expect(result.content).toBe("function test() {\n return false;\n}\n")
} }
}) })
it('should strip line numbers with leading spaces', async () => { it("should strip line numbers with leading spaces", async () => {
const originalContent = 'function test() {\n return true;\n}\n' const originalContent = "function test() {\n return true;\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
1 | function test() { 1 | function test() {
@@ -665,12 +676,12 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('function test() {\n return false;\n}\n') expect(result.content).toBe("function test() {\n return false;\n}\n")
} }
}) })
it('should not strip when not all lines have numbers in either section', async () => { it("should not strip when not all lines have numbers in either section", async () => {
const originalContent = 'function test() {\n return true;\n}\n' const originalContent = "function test() {\n return true;\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
1 | function test() { 1 | function test() {
@@ -686,8 +697,8 @@ function five() {
expect(result.success).toBe(false) expect(result.success).toBe(false)
}) })
it('should preserve content that naturally starts with pipe', async () => { it("should preserve content that naturally starts with pipe", async () => {
const originalContent = '|header|another|\n|---|---|\n|data|more|\n' const originalContent = "|header|another|\n|---|---|\n|data|more|\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
1 | |header|another| 1 | |header|another|
@@ -702,12 +713,12 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('|header|another|\n|---|---|\n|data|updated|\n') expect(result.content).toBe("|header|another|\n|---|---|\n|data|updated|\n")
} }
}) })
it('should preserve indentation when stripping line numbers', async () => { it("should preserve indentation when stripping line numbers", async () => {
const originalContent = ' function test() {\n return true;\n }\n' const originalContent = " function test() {\n return true;\n }\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
1 | function test() { 1 | function test() {
@@ -722,12 +733,12 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe(' function test() {\n return false;\n }\n') expect(result.content).toBe(" function test() {\n return false;\n }\n")
} }
}) })
it('should handle different line numbers between sections', async () => { it("should handle different line numbers between sections", async () => {
const originalContent = 'function test() {\n return true;\n}\n' const originalContent = "function test() {\n return true;\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
10 | function test() { 10 | function test() {
@@ -742,12 +753,12 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('function test() {\n return false;\n}\n') expect(result.content).toBe("function test() {\n return false;\n}\n")
} }
}) })
it('should not strip content that starts with pipe but no line number', async () => { it("should not strip content that starts with pipe but no line number", async () => {
const originalContent = '| Pipe\n|---|\n| Data\n' const originalContent = "| Pipe\n|---|\n| Data\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
| Pipe | Pipe
@@ -762,12 +773,12 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('| Pipe\n|---|\n| Updated\n') expect(result.content).toBe("| Pipe\n|---|\n| Updated\n")
} }
}) })
it('should handle mix of line-numbered and pipe-only content', async () => { it("should handle mix of line-numbered and pipe-only content", async () => {
const originalContent = '| Pipe\n|---|\n| Data\n' const originalContent = "| Pipe\n|---|\n| Data\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
| Pipe | Pipe
@@ -782,21 +793,21 @@ function five() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('1 | | Pipe\n2 | |---|\n3 | | NewData\n') expect(result.content).toBe("1 | | Pipe\n2 | |---|\n3 | | NewData\n")
} }
}) })
}) })
}); })
describe('insertion/deletion', () => { describe("insertion/deletion", () => {
let strategy: SearchReplaceDiffStrategy let strategy: SearchReplaceDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new SearchReplaceDiffStrategy() strategy = new SearchReplaceDiffStrategy()
}) })
describe('deletion', () => { describe("deletion", () => {
it('should delete code when replace block is empty', async () => { it("should delete code when replace block is empty", async () => {
const originalContent = `function test() { const originalContent = `function test() {
console.log("hello"); console.log("hello");
// Comment to remove // Comment to remove
@@ -818,7 +829,7 @@ function five() {
} }
}) })
it('should delete multiple lines when replace block is empty', async () => { it("should delete multiple lines when replace block is empty", async () => {
const originalContent = `class Example { const originalContent = `class Example {
constructor() { constructor() {
// Initialize // Initialize
@@ -848,7 +859,7 @@ function five() {
} }
}) })
it('should preserve indentation when deleting nested code', async () => { it("should preserve indentation when deleting nested code", async () => {
const originalContent = `function outer() { const originalContent = `function outer() {
if (true) { if (true) {
// Remove this // Remove this
@@ -877,8 +888,8 @@ function five() {
}) })
}) })
describe('insertion', () => { describe("insertion", () => {
it('should insert code at specified line when search block is empty', async () => { it("should insert code at specified line when search block is empty", async () => {
const originalContent = `function test() { const originalContent = `function test() {
const x = 1; const x = 1;
return x; return x;
@@ -900,7 +911,7 @@ function five() {
} }
}) })
it('should preserve indentation when inserting at nested location', async () => { it("should preserve indentation when inserting at nested location", async () => {
const originalContent = `function test() { const originalContent = `function test() {
if (true) { if (true) {
const x = 1; const x = 1;
@@ -926,7 +937,7 @@ function five() {
} }
}) })
it('should handle insertion at start of file', async () => { it("should handle insertion at start of file", async () => {
const originalContent = `function test() { const originalContent = `function test() {
return true; return true;
}` }`
@@ -950,7 +961,7 @@ function test() {
} }
}) })
it('should handle insertion at end of file', async () => { it("should handle insertion at end of file", async () => {
const originalContent = `function test() { const originalContent = `function test() {
return true; return true;
}` }`
@@ -972,7 +983,7 @@ function test() {
} }
}) })
it('should error if no start_line is provided for insertion', async () => { it("should error if no start_line is provided for insertion", async () => {
const originalContent = `function test() { const originalContent = `function test() {
return true; return true;
}` }`
@@ -988,14 +999,15 @@ console.log("test");
}) })
}) })
describe('fuzzy matching', () => { describe("fuzzy matching", () => {
let strategy: SearchReplaceDiffStrategy let strategy: SearchReplaceDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new SearchReplaceDiffStrategy(0.9, 5) // 90% similarity threshold, 5 line buffer for tests strategy = new SearchReplaceDiffStrategy(0.9, 5) // 90% similarity threshold, 5 line buffer for tests
}) })
it('should match content with small differences (>90% similar)', async () => { it("should match content with small differences (>90% similar)", async () => {
const originalContent = 'function getData() {\n const results = fetchData();\n return results.filter(Boolean);\n}\n' const originalContent =
"function getData() {\n const results = fetchData();\n return results.filter(Boolean);\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
function getData() { function getData() {
@@ -1014,12 +1026,14 @@ function getData() {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('function getData() {\n const data = fetchData();\n return data.filter(Boolean);\n}\n') expect(result.content).toBe(
"function getData() {\n const data = fetchData();\n return data.filter(Boolean);\n}\n",
)
} }
}) })
it('should not match when content is too different (<90% similar)', async () => { it("should not match when content is too different (<90% similar)", async () => {
const originalContent = 'function processUsers(data) {\n return data.map(user => user.name);\n}\n' const originalContent = "function processUsers(data) {\n return data.map(user => user.name);\n}\n"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
function handleItems(items) { function handleItems(items) {
@@ -1035,8 +1049,8 @@ function processData(data) {
expect(result.success).toBe(false) expect(result.success).toBe(false)
}) })
it('should match content with extra whitespace', async () => { it("should match content with extra whitespace", async () => {
const originalContent = 'function sum(a, b) {\n return a + b;\n}' const originalContent = "function sum(a, b) {\n return a + b;\n}"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
function sum(a, b) { function sum(a, b) {
@@ -1051,12 +1065,12 @@ function sum(a, b) {
const result = await strategy.applyDiff(originalContent, diffContent) const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true) expect(result.success).toBe(true)
if (result.success) { if (result.success) {
expect(result.content).toBe('function sum(a, b) {\n return a + b + 1;\n}') expect(result.content).toBe("function sum(a, b) {\n return a + b + 1;\n}")
} }
}) })
it('should not exact match empty lines', async () => { it("should not exact match empty lines", async () => {
const originalContent = 'function sum(a, b) {\n\n return a + b;\n}' const originalContent = "function sum(a, b) {\n\n return a + b;\n}"
const diffContent = `test.ts const diffContent = `test.ts
<<<<<<< SEARCH <<<<<<< SEARCH
function sum(a, b) { function sum(a, b) {
@@ -1073,14 +1087,14 @@ function sum(a, b) {
}) })
}) })
describe('line-constrained search', () => { describe("line-constrained search", () => {
let strategy: SearchReplaceDiffStrategy let strategy: SearchReplaceDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new SearchReplaceDiffStrategy(0.9, 5) strategy = new SearchReplaceDiffStrategy(0.9, 5)
}) })
it('should find and replace within specified line range', async () => { it("should find and replace within specified line range", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1122,7 +1136,7 @@ function three() {
} }
}) })
it('should find and replace within buffer zone (5 lines before/after)', async () => { it("should find and replace within buffer zone (5 lines before/after)", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1166,7 +1180,7 @@ function three() {
} }
}) })
it('should not find matches outside search range and buffer zone', async () => { it("should not find matches outside search range and buffer zone", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1205,7 +1219,7 @@ function five() {
expect(result.success).toBe(false) expect(result.success).toBe(false)
}) })
it('should handle search range at start of file', async () => { it("should handle search range at start of file", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1239,7 +1253,7 @@ function two() {
} }
}) })
it('should handle search range at end of file', async () => { it("should handle search range at end of file", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1273,7 +1287,7 @@ function two() {
} }
}) })
it('should match specific instance of duplicate code using line numbers', async () => { it("should match specific instance of duplicate code using line numbers", async () => {
const originalContent = ` const originalContent = `
function processData(data) { function processData(data) {
return data.map(x => x * 2); return data.map(x => x * 2);
@@ -1330,7 +1344,7 @@ function moreStuff() {
} }
}) })
it('should search from start line to end of file when only start_line is provided', async () => { it("should search from start line to end of file when only start_line is provided", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1373,7 +1387,7 @@ function three() {
} }
}) })
it('should search from start of file to end line when only end_line is provided', async () => { it("should search from start of file to end line when only end_line is provided", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1416,7 +1430,7 @@ function three() {
} }
}) })
it('should prioritize exact line match over expanded search', async () => { it("should prioritize exact line match over expanded search", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1468,7 +1482,7 @@ function two() {
} }
}) })
it('should fall back to expanded search only if exact match fails', async () => { it("should fall back to expanded search only if exact match fails", async () => {
const originalContent = ` const originalContent = `
function one() { function one() {
return 1; return 1;
@@ -1512,32 +1526,32 @@ function two() {
}) })
}) })
describe('getToolDescription', () => { describe("getToolDescription", () => {
let strategy: SearchReplaceDiffStrategy let strategy: SearchReplaceDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new SearchReplaceDiffStrategy() strategy = new SearchReplaceDiffStrategy()
}) })
it('should include the current working directory', async () => { it("should include the current working directory", async () => {
const cwd = '/test/dir' const cwd = "/test/dir"
const description = await strategy.getToolDescription(cwd) const description = await strategy.getToolDescription({ cwd })
expect(description).toContain(`relative to the current working directory ${cwd}`) expect(description).toContain(`relative to the current working directory ${cwd}`)
}) })
it('should include required format elements', async () => { it("should include required format elements", async () => {
const description = await strategy.getToolDescription('/test') const description = await strategy.getToolDescription({ cwd: "/test" })
expect(description).toContain('<<<<<<< SEARCH') expect(description).toContain("<<<<<<< SEARCH")
expect(description).toContain('=======') expect(description).toContain("=======")
expect(description).toContain('>>>>>>> REPLACE') expect(description).toContain(">>>>>>> REPLACE")
expect(description).toContain('<apply_diff>') expect(description).toContain("<apply_diff>")
expect(description).toContain('</apply_diff>') expect(description).toContain("</apply_diff>")
}) })
it('should document start_line and end_line parameters', async () => { it("should document start_line and end_line parameters", async () => {
const description = await strategy.getToolDescription('/test') const description = await strategy.getToolDescription({ cwd: "/test" })
expect(description).toContain('start_line: (required) The line number where the search block starts.') expect(description).toContain("start_line: (required) The line number where the search block starts.")
expect(description).toContain('end_line: (required) The line number where the search block ends.') expect(description).toContain("end_line: (required) The line number where the search block ends.")
}) })
}) })
}) })

View File

@@ -1,26 +1,26 @@
import { UnifiedDiffStrategy } from '../unified' import { UnifiedDiffStrategy } from "../unified"
describe('UnifiedDiffStrategy', () => { describe("UnifiedDiffStrategy", () => {
let strategy: UnifiedDiffStrategy let strategy: UnifiedDiffStrategy
beforeEach(() => { beforeEach(() => {
strategy = new UnifiedDiffStrategy() strategy = new UnifiedDiffStrategy()
}) })
describe('getToolDescription', () => { describe("getToolDescription", () => {
it('should return tool description with correct cwd', () => { it("should return tool description with correct cwd", () => {
const cwd = '/test/path' const cwd = "/test/path"
const description = strategy.getToolDescription(cwd) const description = strategy.getToolDescription({ cwd })
expect(description).toContain('apply_diff') expect(description).toContain("apply_diff")
expect(description).toContain(cwd) expect(description).toContain(cwd)
expect(description).toContain('Parameters:') expect(description).toContain("Parameters:")
expect(description).toContain('Format Requirements:') expect(description).toContain("Format Requirements:")
}) })
}) })
describe('applyDiff', () => { describe("applyDiff", () => {
it('should successfully apply a function modification diff', async () => { it("should successfully apply a function modification diff", async () => {
const originalContent = `import { Logger } from '../logger'; const originalContent = `import { Logger } from '../logger';
function calculateTotal(items: number[]): number { function calculateTotal(items: number[]): number {
@@ -65,7 +65,7 @@ export { calculateTotal };`
} }
}) })
it('should successfully apply a diff adding a new method', async () => { it("should successfully apply a diff adding a new method", async () => {
const originalContent = `class Calculator { const originalContent = `class Calculator {
add(a: number, b: number): number { add(a: number, b: number): number {
return a + b; return a + b;
@@ -102,7 +102,7 @@ export { calculateTotal };`
} }
}) })
it('should successfully apply a diff modifying imports', async () => { it("should successfully apply a diff modifying imports", async () => {
const originalContent = `import { useState } from 'react'; const originalContent = `import { useState } from 'react';
import { Button } from './components'; import { Button } from './components';
@@ -140,7 +140,7 @@ function App() {
} }
}) })
it('should successfully apply a diff with multiple hunks', async () => { it("should successfully apply a diff with multiple hunks", async () => {
const originalContent = `import { readFile, writeFile } from 'fs'; const originalContent = `import { readFile, writeFile } from 'fs';
function processFile(path: string) { function processFile(path: string) {
@@ -205,8 +205,8 @@ export { processFile };`
} }
}) })
it('should handle empty original content', async () => { it("should handle empty original content", async () => {
const originalContent = '' const originalContent = ""
const diffContent = `--- empty.ts const diffContent = `--- empty.ts
+++ empty.ts +++ empty.ts
@@ -0,0 +1,3 @@ @@ -0,0 +1,3 @@
@@ -226,4 +226,3 @@ export { processFile };`
}) })
}) })
}) })

View File

@@ -265,8 +265,8 @@ describe("applyGitFallback", () => {
{ type: "context", content: "line1", indent: "" }, { type: "context", content: "line1", indent: "" },
{ type: "remove", content: "line2", indent: "" }, { type: "remove", content: "line2", indent: "" },
{ type: "add", content: "new line2", indent: "" }, { type: "add", content: "new line2", indent: "" },
{ type: "context", content: "line3", indent: "" } { type: "context", content: "line3", indent: "" },
] ],
} as Hunk } as Hunk
const content = ["line1", "line2", "line3"] const content = ["line1", "line2", "line3"]
@@ -281,8 +281,8 @@ describe("applyGitFallback", () => {
const hunk = { const hunk = {
changes: [ changes: [
{ type: "context", content: "nonexistent", indent: "" }, { type: "context", content: "nonexistent", indent: "" },
{ type: "add", content: "new line", indent: "" } { type: "add", content: "new line", indent: "" },
] ],
} as Hunk } as Hunk
const content = ["line1", "line2", "line3"] const content = ["line1", "line2", "line3"]

View File

@@ -3,7 +3,7 @@ import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMa
type SearchStrategy = ( type SearchStrategy = (
searchStr: string, searchStr: string,
content: string[], content: string[],
startIndex?: number startIndex?: number,
) => { ) => {
index: number index: number
confidence: number confidence: number

View File

@@ -28,11 +28,7 @@ function inferIndentation(line: string, contextLines: string[], previousIndent:
} }
// Context matching edit strategy // Context matching edit strategy
export function applyContextMatching( export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
hunk: Hunk,
content: string[],
matchPosition: number,
): EditResult {
if (matchPosition === -1) { if (matchPosition === -1) {
return { confidence: 0, result: content, strategy: "context" } return { confidence: 0, result: content, strategy: "context" }
} }
@@ -85,16 +81,12 @@ export function applyContextMatching(
return { return {
confidence, confidence,
result: newResult, result: newResult,
strategy: "context" strategy: "context",
} }
} }
// DMP edit strategy // DMP edit strategy
export function applyDMP( export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
hunk: Hunk,
content: string[],
matchPosition: number,
): EditResult {
if (matchPosition === -1) { if (matchPosition === -1) {
return { confidence: 0, result: content, strategy: "dmp" } return { confidence: 0, result: content, strategy: "dmp" }
} }
@@ -276,12 +268,12 @@ export async function applyEdit(
content: string[], content: string[],
matchPosition: number, matchPosition: number,
confidence: number, confidence: number,
confidenceThreshold: number = 0.97 confidenceThreshold: number = 0.97,
): Promise<EditResult> { ): Promise<EditResult> {
// Don't attempt regular edits if confidence is too low // Don't attempt regular edits if confidence is too low
if (confidence < confidenceThreshold) { if (confidence < confidenceThreshold) {
console.log( console.log(
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...` `Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
) )
return applyGitFallback(hunk, content) return applyGitFallback(hunk, content)
} }

View File

@@ -164,8 +164,8 @@ Generate a unified diff that can be cleanly applied to modify code files.
\`\`\` \`\`\`
Parameters: Parameters:
- path: (required) File path relative to ${cwd} - path: (required) The path of the file to apply the diff to (relative to the current working directory ${args.cwd})
- diff: (required) Unified diff content - diff: (required) The diff content in unified format to apply to the file.
Usage: Usage:
<apply_diff> <apply_diff>
@@ -233,7 +233,7 @@ Your diff here
originalContent: string, originalContent: string,
diffContent: string, diffContent: string,
startLine?: number, startLine?: number,
endLine?: number endLine?: number,
): Promise<DiffResult> { ): Promise<DiffResult> {
const parsedDiff = this.parseUnifiedDiff(diffContent) const parsedDiff = this.parseUnifiedDiff(diffContent)
const originalLines = originalContent.split("\n") const originalLines = originalContent.split("\n")
@@ -271,7 +271,7 @@ Your diff here
subHunkResult, subHunkResult,
subSearchResult.index, subSearchResult.index,
subSearchResult.confidence, subSearchResult.confidence,
this.confidenceThreshold this.confidenceThreshold,
) )
if (subEditResult.confidence >= this.confidenceThreshold) { if (subEditResult.confidence >= this.confidenceThreshold) {
subHunkResult = subEditResult.result subHunkResult = subEditResult.result
@@ -293,12 +293,12 @@ Your diff here
const contextRatio = contextLines / totalLines const contextRatio = contextLines / totalLines
let errorMsg = `Failed to find a matching location in the file (${Math.floor( let errorMsg = `Failed to find a matching location in the file (${Math.floor(
confidence * 100 confidence * 100,
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n` )}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
errorMsg += "Debug Info:\n" errorMsg += "Debug Info:\n"
errorMsg += `- Search Strategy Used: ${strategy}\n` errorMsg += `- Search Strategy Used: ${strategy}\n`
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor( errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
contextRatio * 100 contextRatio * 100,
)}%)\n` )}%)\n`
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n` errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
@@ -330,7 +330,7 @@ Your diff here
} else { } else {
// Edit failure - likely due to content mismatch // Edit failure - likely due to content mismatch
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor( let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
editResult.confidence * 100 editResult.confidence * 100,
)}% confidence)\n\n` )}% confidence)\n\n`
errorMsg += "Debug Info:\n" errorMsg += "Debug Info:\n"
errorMsg += "- The location was found but the content didn't match exactly\n" errorMsg += "- The location was found but the content didn't match exactly\n"

View File

@@ -69,26 +69,26 @@ export function getDMPSimilarity(original: string, modified: string): number {
export function validateEditResult(hunk: Hunk, result: string): number { export function validateEditResult(hunk: Hunk, result: string): number {
// Build the expected text from the hunk // Build the expected text from the hunk
const expectedText = hunk.changes const expectedText = hunk.changes
.filter(change => change.type === "context" || change.type === "add") .filter((change) => change.type === "context" || change.type === "add")
.map(change => change.indent ? change.indent + change.content : change.content) .map((change) => (change.indent ? change.indent + change.content : change.content))
.join("\n"); .join("\n")
// Calculate similarity between the result and expected text // Calculate similarity between the result and expected text
const similarity = getDMPSimilarity(expectedText, result); const similarity = getDMPSimilarity(expectedText, result)
// If the result is unchanged from original, return low confidence // If the result is unchanged from original, return low confidence
const originalText = hunk.changes const originalText = hunk.changes
.filter(change => change.type === "context" || change.type === "remove") .filter((change) => change.type === "context" || change.type === "remove")
.map(change => change.indent ? change.indent + change.content : change.content) .map((change) => (change.indent ? change.indent + change.content : change.content))
.join("\n"); .join("\n")
const originalSimilarity = getDMPSimilarity(originalText, result); const originalSimilarity = getDMPSimilarity(originalText, result)
if (originalSimilarity > 0.97 && similarity !== 1) { if (originalSimilarity > 0.97 && similarity !== 1) {
return 0.8 * similarity; // Some confidence since we found the right location return 0.8 * similarity // Some confidence since we found the right location
} }
// For partial matches, scale the confidence but keep it high if we're close // For partial matches, scale the confidence but keep it high if we're close
return similarity; return similarity
} }
// Helper function to validate context lines against original content // Helper function to validate context lines against original content
@@ -114,7 +114,7 @@ function validateContextLines(searchStr: string, content: string, confidenceThre
function createOverlappingWindows( function createOverlappingWindows(
content: string[], content: string[],
searchSize: number, searchSize: number,
overlapSize: number = DEFAULT_OVERLAP_SIZE overlapSize: number = DEFAULT_OVERLAP_SIZE,
): { window: string[]; startIndex: number }[] { ): { window: string[]; startIndex: number }[] {
const windows: { window: string[]; startIndex: number }[] = [] const windows: { window: string[]; startIndex: number }[] = []
@@ -140,7 +140,7 @@ function createOverlappingWindows(
// Helper function to combine overlapping matches // Helper function to combine overlapping matches
function combineOverlappingMatches( function combineOverlappingMatches(
matches: (SearchResult & { windowIndex: number })[], matches: (SearchResult & { windowIndex: number })[],
overlapSize: number = DEFAULT_OVERLAP_SIZE overlapSize: number = DEFAULT_OVERLAP_SIZE,
): SearchResult[] { ): SearchResult[] {
if (matches.length === 0) { if (matches.length === 0) {
return [] return []
@@ -162,7 +162,7 @@ function combineOverlappingMatches(
(m) => (m) =>
Math.abs(m.windowIndex - match.windowIndex) === 1 && Math.abs(m.windowIndex - match.windowIndex) === 1 &&
Math.abs(m.index - match.index) <= overlapSize && Math.abs(m.index - match.index) <= overlapSize &&
!usedIndices.has(m.windowIndex) !usedIndices.has(m.windowIndex),
) )
if (overlapping.length > 0) { if (overlapping.length > 0) {
@@ -196,7 +196,7 @@ export function findExactMatch(
searchStr: string, searchStr: string,
content: string[], content: string[],
startIndex: number = 0, startIndex: number = 0,
confidenceThreshold: number = 0.97 confidenceThreshold: number = 0.97,
): SearchResult { ): SearchResult {
const searchLines = searchStr.split("\n") const searchLines = searchStr.split("\n")
const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length) const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
@@ -210,7 +210,7 @@ export function findExactMatch(
const matchedContent = windowData.window const matchedContent = windowData.window
.slice( .slice(
windowStr.slice(0, exactMatch).split("\n").length - 1, windowStr.slice(0, exactMatch).split("\n").length - 1,
windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length,
) )
.join("\n") .join("\n")
@@ -236,7 +236,7 @@ export function findSimilarityMatch(
searchStr: string, searchStr: string,
content: string[], content: string[],
startIndex: number = 0, startIndex: number = 0,
confidenceThreshold: number = 0.97 confidenceThreshold: number = 0.97,
): SearchResult { ): SearchResult {
const searchLines = searchStr.split("\n") const searchLines = searchStr.split("\n")
let bestScore = 0 let bestScore = 0
@@ -269,7 +269,7 @@ export function findLevenshteinMatch(
searchStr: string, searchStr: string,
content: string[], content: string[],
startIndex: number = 0, startIndex: number = 0,
confidenceThreshold: number = 0.97 confidenceThreshold: number = 0.97,
): SearchResult { ): SearchResult {
const searchLines = searchStr.split("\n") const searchLines = searchStr.split("\n")
const candidates = [] const candidates = []
@@ -324,7 +324,7 @@ export function findAnchorMatch(
searchStr: string, searchStr: string,
content: string[], content: string[],
startIndex: number = 0, startIndex: number = 0,
confidenceThreshold: number = 0.97 confidenceThreshold: number = 0.97,
): SearchResult { ): SearchResult {
const searchLines = searchStr.split("\n") const searchLines = searchStr.split("\n")
const { first, last } = identifyAnchors(searchStr) const { first, last } = identifyAnchors(searchStr)
@@ -391,7 +391,7 @@ export function findBestMatch(
searchStr: string, searchStr: string,
content: string[], content: string[],
startIndex: number = 0, startIndex: number = 0,
confidenceThreshold: number = 0.97 confidenceThreshold: number = 0.97,
): SearchResult { ): SearchResult {
const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch] const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]

View File

@@ -1,20 +1,20 @@
export type Change = { export type Change = {
type: 'context' | 'add' | 'remove'; type: "context" | "add" | "remove"
content: string; content: string
indent: string; indent: string
originalLine?: string; originalLine?: string
}; }
export type Hunk = { export type Hunk = {
changes: Change[]; changes: Change[]
}; }
export type Diff = { export type Diff = {
hunks: Hunk[]; hunks: Hunk[]
}; }
export type EditResult = { export type EditResult = {
confidence: number; confidence: number
result: string[]; result: string[]
strategy: string; strategy: string
}; }

View File

@@ -1,71 +1,73 @@
import { DiffStrategy, DiffResult } from "../types" import { DiffStrategy, DiffResult } from "../types"
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text" import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
const BUFFER_LINES = 20; // Number of extra context lines to show before and after matches const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches
function levenshteinDistance(a: string, b: string): number { function levenshteinDistance(a: string, b: string): number {
const matrix: number[][] = []; const matrix: number[][] = []
// Initialize matrix // Initialize matrix
for (let i = 0; i <= a.length; i++) { for (let i = 0; i <= a.length; i++) {
matrix[i] = [i]; matrix[i] = [i]
} }
for (let j = 0; j <= b.length; j++) { for (let j = 0; j <= b.length; j++) {
matrix[0][j] = j; matrix[0][j] = j
} }
// Fill matrix // Fill matrix
for (let i = 1; i <= a.length; i++) { for (let i = 1; i <= a.length; i++) {
for (let j = 1; j <= b.length; j++) { for (let j = 1; j <= b.length; j++) {
if (a[i-1] === b[j-1]) { if (a[i - 1] === b[j - 1]) {
matrix[i][j] = matrix[i-1][j-1]; matrix[i][j] = matrix[i - 1][j - 1]
} else { } else {
matrix[i][j] = Math.min( matrix[i][j] = Math.min(
matrix[i-1][j-1] + 1, // substitution matrix[i - 1][j - 1] + 1, // substitution
matrix[i][j-1] + 1, // insertion matrix[i][j - 1] + 1, // insertion
matrix[i-1][j] + 1 // deletion matrix[i - 1][j] + 1, // deletion
); )
} }
} }
} }
return matrix[a.length][b.length]; return matrix[a.length][b.length]
} }
function getSimilarity(original: string, search: string): number { function getSimilarity(original: string, search: string): number {
if (search === '') { if (search === "") {
return 1; return 1
} }
// Normalize strings by removing extra whitespace but preserve case // Normalize strings by removing extra whitespace but preserve case
const normalizeStr = (str: string) => str.replace(/\s+/g, ' ').trim(); const normalizeStr = (str: string) => str.replace(/\s+/g, " ").trim()
const normalizedOriginal = normalizeStr(original); const normalizedOriginal = normalizeStr(original)
const normalizedSearch = normalizeStr(search); const normalizedSearch = normalizeStr(search)
if (normalizedOriginal === normalizedSearch) { return 1; } if (normalizedOriginal === normalizedSearch) {
return 1
}
// Calculate Levenshtein distance // Calculate Levenshtein distance
const distance = levenshteinDistance(normalizedOriginal, normalizedSearch); const distance = levenshteinDistance(normalizedOriginal, normalizedSearch)
// Calculate similarity ratio (0 to 1, where 1 is exact match) // Calculate similarity ratio (0 to 1, where 1 is exact match)
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length); const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length)
return 1 - (distance / maxLength); return 1 - distance / maxLength
} }
export class SearchReplaceDiffStrategy implements DiffStrategy { export class SearchReplaceDiffStrategy implements DiffStrategy {
private fuzzyThreshold: number; private fuzzyThreshold: number
private bufferLines: number; private bufferLines: number
constructor(fuzzyThreshold?: number, bufferLines?: number) { constructor(fuzzyThreshold?: number, bufferLines?: number) {
// Use provided threshold or default to exact matching (1.0) // Use provided threshold or default to exact matching (1.0)
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9) // Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
// so we use it directly here // so we use it directly here
this.fuzzyThreshold = fuzzyThreshold ?? 1.0; this.fuzzyThreshold = fuzzyThreshold ?? 1.0
this.bufferLines = bufferLines ?? BUFFER_LINES; this.bufferLines = bufferLines ?? BUFFER_LINES
} }
getToolDescription(cwd: string): string { getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
return `## apply_diff return `## apply_diff
Description: Request to replace existing code using a search and replace block. Description: Request to replace existing code using a search and replace block.
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with. This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
@@ -76,7 +78,7 @@ If you're not confident in the exact content to search for, use the read_file to
When applying the diffs, be extra careful to remember to change any closing brackets or other syntax that may be affected by the diff farther down in the file. When applying the diffs, be extra careful to remember to change any closing brackets or other syntax that may be affected by the diff farther down in the file.
Parameters: Parameters:
- path: (required) The path of the file to modify (relative to the current working directory ${cwd}) - path: (required) The path of the file to modify (relative to the current working directory ${args.cwd})
- diff: (required) The search/replace block defining the changes. - diff: (required) The search/replace block defining the changes.
- start_line: (required) The line number where the search block starts. - start_line: (required) The line number where the search block starts.
- end_line: (required) The line number where the search block ends. - end_line: (required) The line number where the search block ends.
@@ -127,191 +129,202 @@ Your search/replace content here
</apply_diff>` </apply_diff>`
} }
async applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult> { async applyDiff(
originalContent: string,
diffContent: string,
startLine?: number,
endLine?: number,
): Promise<DiffResult> {
// Extract the search and replace blocks // Extract the search and replace blocks
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/); const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
if (!match) { if (!match) {
return { return {
success: false, success: false,
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers` error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
}; }
} }
let [_, searchContent, replaceContent] = match; let [_, searchContent, replaceContent] = match
// Detect line ending from original content // Detect line ending from original content
const lineEnding = originalContent.includes('\r\n') ? '\r\n' : '\n'; const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"
// Strip line numbers from search and replace content if every line starts with a line number // Strip line numbers from search and replace content if every line starts with a line number
if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) { if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
searchContent = stripLineNumbers(searchContent); searchContent = stripLineNumbers(searchContent)
replaceContent = stripLineNumbers(replaceContent); replaceContent = stripLineNumbers(replaceContent)
} }
// Split content into lines, handling both \n and \r\n // Split content into lines, handling both \n and \r\n
const searchLines = searchContent === '' ? [] : searchContent.split(/\r?\n/); const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
const replaceLines = replaceContent === '' ? [] : replaceContent.split(/\r?\n/); const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
const originalLines = originalContent.split(/\r?\n/); const originalLines = originalContent.split(/\r?\n/)
// Validate that empty search requires start line // Validate that empty search requires start line
if (searchLines.length === 0 && !startLine) { if (searchLines.length === 0 && !startLine) {
return { return {
success: false, success: false,
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted` error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
}; }
} }
// Validate that empty search requires same start and end line // Validate that empty search requires same start and end line
if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) { if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
return { return {
success: false, success: false,
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line` error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
}; }
} }
// Initialize search variables // Initialize search variables
let matchIndex = -1; let matchIndex = -1
let bestMatchScore = 0; let bestMatchScore = 0
let bestMatchContent = ""; let bestMatchContent = ""
const searchChunk = searchLines.join('\n'); const searchChunk = searchLines.join("\n")
// Determine search bounds // Determine search bounds
let searchStartIndex = 0; let searchStartIndex = 0
let searchEndIndex = originalLines.length; let searchEndIndex = originalLines.length
// Validate and handle line range if provided // Validate and handle line range if provided
if (startLine && endLine) { if (startLine && endLine) {
// Convert to 0-based index // Convert to 0-based index
const exactStartIndex = startLine - 1; const exactStartIndex = startLine - 1
const exactEndIndex = endLine - 1; const exactEndIndex = endLine - 1
if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) { if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
return { return {
success: false, success: false,
error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`, error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
}; }
} }
// Try exact match first // Try exact match first
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join('\n'); const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
const similarity = getSimilarity(originalChunk, searchChunk); const similarity = getSimilarity(originalChunk, searchChunk)
if (similarity >= this.fuzzyThreshold) { if (similarity >= this.fuzzyThreshold) {
matchIndex = exactStartIndex; matchIndex = exactStartIndex
bestMatchScore = similarity; bestMatchScore = similarity
bestMatchContent = originalChunk; bestMatchContent = originalChunk
} else { } else {
// Set bounds for buffered search // Set bounds for buffered search
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1)); searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines); searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
} }
} }
// If no match found yet, try middle-out search within bounds // If no match found yet, try middle-out search within bounds
if (matchIndex === -1) { if (matchIndex === -1) {
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2); const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
let leftIndex = midPoint; let leftIndex = midPoint
let rightIndex = midPoint + 1; let rightIndex = midPoint + 1
// Search outward from the middle within bounds // Search outward from the middle within bounds
while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) { while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
// Check left side if still in range // Check left side if still in range
if (leftIndex >= searchStartIndex) { if (leftIndex >= searchStartIndex) {
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join('\n'); const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
const similarity = getSimilarity(originalChunk, searchChunk); const similarity = getSimilarity(originalChunk, searchChunk)
if (similarity > bestMatchScore) { if (similarity > bestMatchScore) {
bestMatchScore = similarity; bestMatchScore = similarity
matchIndex = leftIndex; matchIndex = leftIndex
bestMatchContent = originalChunk; bestMatchContent = originalChunk
} }
leftIndex--; leftIndex--
} }
// Check right side if still in range // Check right side if still in range
if (rightIndex <= searchEndIndex - searchLines.length) { if (rightIndex <= searchEndIndex - searchLines.length) {
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join('\n'); const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
const similarity = getSimilarity(originalChunk, searchChunk); const similarity = getSimilarity(originalChunk, searchChunk)
if (similarity > bestMatchScore) { if (similarity > bestMatchScore) {
bestMatchScore = similarity; bestMatchScore = similarity
matchIndex = rightIndex; matchIndex = rightIndex
bestMatchContent = originalChunk; bestMatchContent = originalChunk
} }
rightIndex++; rightIndex++
} }
} }
} }
// Require similarity to meet threshold // Require similarity to meet threshold
if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) { if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
const searchChunk = searchLines.join('\n'); const searchChunk = searchLines.join("\n")
const originalContentSection = startLine !== undefined && endLine !== undefined const originalContentSection =
startLine !== undefined && endLine !== undefined
? `\n\nOriginal Content:\n${addLineNumbers( ? `\n\nOriginal Content:\n${addLineNumbers(
originalLines.slice( originalLines
.slice(
Math.max(0, startLine - 1 - this.bufferLines), Math.max(0, startLine - 1 - this.bufferLines),
Math.min(originalLines.length, endLine + this.bufferLines) Math.min(originalLines.length, endLine + this.bufferLines),
).join('\n'), )
Math.max(1, startLine - this.bufferLines) .join("\n"),
Math.max(1, startLine - this.bufferLines),
)}` )}`
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join('\n'))}`; : `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`
const bestMatchSection = bestMatchContent const bestMatchSection = bestMatchContent
? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}` ? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
: `\n\nBest Match Found:\n(no match)`; : `\n\nBest Match Found:\n(no match)`
const lineRange = startLine || endLine ? const lineRange =
` at ${startLine ? `start: ${startLine}` : 'start'} to ${endLine ? `end: ${endLine}` : 'end'}` : ''; startLine || endLine
? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
: ""
return { return {
success: false, success: false,
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : 'start to end'}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}` error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
}; }
} }
// Get the matched lines from the original content // Get the matched lines from the original content
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length); const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)
// Get the exact indentation (preserving tabs/spaces) of each line // Get the exact indentation (preserving tabs/spaces) of each line
const originalIndents = matchedLines.map(line => { const originalIndents = matchedLines.map((line) => {
const match = line.match(/^[\t ]*/); const match = line.match(/^[\t ]*/)
return match ? match[0] : ''; return match ? match[0] : ""
}); })
// Get the exact indentation of each line in the search block // Get the exact indentation of each line in the search block
const searchIndents = searchLines.map(line => { const searchIndents = searchLines.map((line) => {
const match = line.match(/^[\t ]*/); const match = line.match(/^[\t ]*/)
return match ? match[0] : ''; return match ? match[0] : ""
}); })
// Apply the replacement while preserving exact indentation // Apply the replacement while preserving exact indentation
const indentedReplaceLines = replaceLines.map((line, i) => { const indentedReplaceLines = replaceLines.map((line, i) => {
// Get the matched line's exact indentation // Get the matched line's exact indentation
const matchedIndent = originalIndents[0] || ''; const matchedIndent = originalIndents[0] || ""
// Get the current line's indentation relative to the search content // Get the current line's indentation relative to the search content
const currentIndentMatch = line.match(/^[\t ]*/); const currentIndentMatch = line.match(/^[\t ]*/)
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ''; const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
const searchBaseIndent = searchIndents[0] || ''; const searchBaseIndent = searchIndents[0] || ""
// Calculate the relative indentation level // Calculate the relative indentation level
const searchBaseLevel = searchBaseIndent.length; const searchBaseLevel = searchBaseIndent.length
const currentLevel = currentIndent.length; const currentLevel = currentIndent.length
const relativeLevel = currentLevel - searchBaseLevel; const relativeLevel = currentLevel - searchBaseLevel
// If relative level is negative, remove indentation from matched indent // If relative level is negative, remove indentation from matched indent
// If positive, add to matched indent // If positive, add to matched indent
const finalIndent = relativeLevel < 0 const finalIndent =
relativeLevel < 0
? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel)) ? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
: matchedIndent + currentIndent.slice(searchBaseLevel); : matchedIndent + currentIndent.slice(searchBaseLevel)
return finalIndent + line.trim(); return finalIndent + line.trim()
}); })
// Construct the final content // Construct the final content
const beforeMatch = originalLines.slice(0, matchIndex); const beforeMatch = originalLines.slice(0, matchIndex)
const afterMatch = originalLines.slice(matchIndex + searchLines.length); const afterMatch = originalLines.slice(matchIndex + searchLines.length)
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding); const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
return { return {
success: true, success: true,
content: finalContent content: finalContent,
}; }
} }
} }

View File

@@ -2,12 +2,12 @@ import { applyPatch } from "diff"
import { DiffStrategy, DiffResult } from "../types" import { DiffStrategy, DiffResult } from "../types"
export class UnifiedDiffStrategy implements DiffStrategy { export class UnifiedDiffStrategy implements DiffStrategy {
getToolDescription(cwd: string): string { getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
return `## apply_diff return `## apply_diff
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3). Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
Parameters: Parameters:
- path: (required) The path of the file to apply the diff to (relative to the current working directory ${cwd}) - path: (required) The path of the file to apply the diff to (relative to the current working directory ${args.cwd})
- diff: (required) The diff content in unified format to apply to the file. - diff: (required) The diff content in unified format to apply to the file.
Format Requirements: Format Requirements:
@@ -116,21 +116,21 @@ Your diff here
success: false, success: false,
error: "Failed to apply unified diff - patch rejected", error: "Failed to apply unified diff - patch rejected",
details: { details: {
searchContent: diffContent searchContent: diffContent,
} },
} }
} }
return { return {
success: true, success: true,
content: result content: result,
} }
} catch (error) { } catch (error) {
return { return {
success: false, success: false,
error: `Error applying unified diff: ${error.message}`, error: `Error applying unified diff: ${error.message}`,
details: { details: {
searchContent: diffContent searchContent: diffContent,
} },
} }
} }
} }

View File

@@ -4,21 +4,25 @@
export type DiffResult = export type DiffResult =
| { success: true; content: string } | { success: true; content: string }
| { success: false; error: string; details?: { | {
similarity?: number; success: false
threshold?: number; error: string
matchedRange?: { start: number; end: number }; details?: {
searchContent?: string; similarity?: number
bestMatch?: string; threshold?: number
}}; matchedRange?: { start: number; end: number }
searchContent?: string
bestMatch?: string
}
}
export interface DiffStrategy { export interface DiffStrategy {
/** /**
* Get the tool description for this diff strategy * Get the tool description for this diff strategy
* @param cwd The current working directory * @param args The tool arguments including cwd and toolOptions
* @returns The complete tool description including format requirements and examples * @returns The complete tool description including format requirements and examples
*/ */
getToolDescription(cwd: string): string getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string
/** /**
* Apply a diff to the original content * Apply a diff to the original content

View File

@@ -1,20 +1,20 @@
// Create mock vscode module before importing anything // Create mock vscode module before importing anything
const createMockUri = (scheme: string, path: string) => ({ const createMockUri = (scheme: string, path: string) => ({
scheme, scheme,
authority: '', authority: "",
path, path,
query: '', query: "",
fragment: '', fragment: "",
fsPath: path, fsPath: path,
with: jest.fn(), with: jest.fn(),
toString: () => path, toString: () => path,
toJSON: () => ({ toJSON: () => ({
scheme, scheme,
authority: '', authority: "",
path, path,
query: '', query: "",
fragment: '' fragment: "",
}) }),
}) })
const mockExecuteCommand = jest.fn() const mockExecuteCommand = jest.fn()
@@ -23,9 +23,11 @@ const mockShowErrorMessage = jest.fn()
const mockVscode = { const mockVscode = {
workspace: { workspace: {
workspaceFolders: [{ workspaceFolders: [
uri: { fsPath: "/test/workspace" } {
}] uri: { fsPath: "/test/workspace" },
},
],
}, },
window: { window: {
showErrorMessage: mockShowErrorMessage, showErrorMessage: mockShowErrorMessage,
@@ -34,17 +36,17 @@ const mockVscode = {
createTextEditorDecorationType: jest.fn(), createTextEditorDecorationType: jest.fn(),
createOutputChannel: jest.fn(), createOutputChannel: jest.fn(),
createWebviewPanel: jest.fn(), createWebviewPanel: jest.fn(),
activeTextEditor: undefined activeTextEditor: undefined,
}, },
commands: { commands: {
executeCommand: mockExecuteCommand executeCommand: mockExecuteCommand,
}, },
env: { env: {
openExternal: mockOpenExternal openExternal: mockOpenExternal,
}, },
Uri: { Uri: {
parse: jest.fn((url: string) => createMockUri('https', url)), parse: jest.fn((url: string) => createMockUri("https", url)),
file: jest.fn((path: string) => createMockUri('file', path)) file: jest.fn((path: string) => createMockUri("file", path)),
}, },
Position: jest.fn(), Position: jest.fn(),
Range: jest.fn(), Range: jest.fn(),
@@ -54,12 +56,12 @@ const mockVscode = {
Error: 0, Error: 0,
Warning: 1, Warning: 1,
Information: 2, Information: 2,
Hint: 3 Hint: 3,
} },
} }
// Mock modules // Mock modules
jest.mock('vscode', () => mockVscode) jest.mock("vscode", () => mockVscode)
jest.mock("../../../services/browser/UrlContentFetcher") jest.mock("../../../services/browser/UrlContentFetcher")
jest.mock("../../../utils/git") jest.mock("../../../utils/git")
@@ -97,11 +99,7 @@ Detailed commit message with multiple lines
jest.mocked(git.getCommitInfo).mockResolvedValue(commitInfo) jest.mocked(git.getCommitInfo).mockResolvedValue(commitInfo)
const result = await parseMentions( const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher)
`Check out this commit @${commitHash}`,
mockCwd,
mockUrlContentFetcher
)
expect(result).toContain(`'${commitHash}' (see below for commit info)`) expect(result).toContain(`'${commitHash}' (see below for commit info)`)
expect(result).toContain(`<git_commit hash="${commitHash}">`) expect(result).toContain(`<git_commit hash="${commitHash}">`)
@@ -114,11 +112,7 @@ Detailed commit message with multiple lines
jest.mocked(git.getCommitInfo).mockRejectedValue(new Error(errorMessage)) jest.mocked(git.getCommitInfo).mockRejectedValue(new Error(errorMessage))
const result = await parseMentions( const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher)
`Check out this commit @${commitHash}`,
mockCwd,
mockUrlContentFetcher
)
expect(result).toContain(`'${commitHash}' (see below for commit info)`) expect(result).toContain(`'${commitHash}' (see below for commit info)`)
expect(result).toContain(`<git_commit hash="${commitHash}">`) expect(result).toContain(`<git_commit hash="${commitHash}">`)
@@ -143,13 +137,15 @@ Detailed commit message with multiple lines
const mockUri = mockVscode.Uri.parse(url) const mockUri = mockVscode.Uri.parse(url)
expect(mockOpenExternal).toHaveBeenCalled() expect(mockOpenExternal).toHaveBeenCalled()
const calledArg = mockOpenExternal.mock.calls[0][0] const calledArg = mockOpenExternal.mock.calls[0][0]
expect(calledArg).toEqual(expect.objectContaining({ expect(calledArg).toEqual(
expect.objectContaining({
scheme: mockUri.scheme, scheme: mockUri.scheme,
authority: mockUri.authority, authority: mockUri.authority,
path: mockUri.path, path: mockUri.path,
query: mockUri.query, query: mockUri.query,
fragment: mockUri.fragment fragment: mockUri.fragment,
})) }),
)
}) })
}) })
}) })

View File

@@ -1,32 +1,10 @@
import { Mode } from './prompts/types' import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from "../shared/modes"
import { codeMode } from './prompts/system'
import { CODE_ALLOWED_TOOLS, READONLY_ALLOWED_TOOLS, ToolName, ReadOnlyToolName } from './tool-lists'
// Extended tool type that includes 'unknown_tool' for testing export { isToolAllowedForMode }
export type TestToolName = ToolName | 'unknown_tool'; export type { TestToolName }
// Type guard to check if a tool is a valid tool
function isValidTool(tool: TestToolName): tool is ToolName {
return CODE_ALLOWED_TOOLS.includes(tool as ToolName);
}
// Type guard to check if a tool is a read-only tool
function isReadOnlyTool(tool: TestToolName): tool is ReadOnlyToolName {
return READONLY_ALLOWED_TOOLS.includes(tool as ReadOnlyToolName);
}
export function isToolAllowedForMode(toolName: TestToolName, mode: Mode): boolean {
if (mode === codeMode) {
return isValidTool(toolName);
}
// Both architect and ask modes use the same read-only tools
return isReadOnlyTool(toolName);
}
export function validateToolUse(toolName: TestToolName, mode: Mode): void { export function validateToolUse(toolName: TestToolName, mode: Mode): void {
if (!isToolAllowedForMode(toolName, mode)) { if (!isToolAllowedForMode(toolName, mode)) {
throw new Error( throw new Error(`Tool "${toolName}" is not allowed in ${mode} mode.`)
`Tool "${toolName}" is not allowed in ${mode} mode.`
);
} }
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,139 +0,0 @@
import { ARCHITECT_PROMPT } from '../architect'
import { McpHub } from '../../../services/mcp/McpHub'
import { SearchReplaceDiffStrategy } from '../../../core/diff/strategies/search-replace'
import fs from 'fs/promises'
import os from 'os'
// Import path utils to get access to toPosix string extension
import '../../../utils/path'
// Mock environment-specific values for consistent tests
jest.mock('os', () => ({
...jest.requireActual('os'),
homedir: () => '/home/user'
}))
jest.mock('default-shell', () => '/bin/bash')
jest.mock('os-name', () => () => 'Linux')
// Mock fs.readFile to return empty mcpServers config
jest.mock('fs/promises', () => ({
...jest.requireActual('fs/promises'),
readFile: jest.fn().mockImplementation(async (path: string) => {
if (path.endsWith('mcpSettings.json')) {
return '{"mcpServers": {}}'
}
if (path.endsWith('.clinerules')) {
return '# Test Rules\n1. First rule\n2. Second rule'
}
return ''
}),
writeFile: jest.fn().mockResolvedValue(undefined)
}))
// Instead of extending McpHub, create a mock that implements just what we need
const createMockMcpHub = (): McpHub => ({
getServers: () => [],
getMcpServersPath: async () => '/mock/mcp/path',
getMcpSettingsFilePath: async () => '/mock/settings/path',
dispose: async () => {},
// Add other required public methods with no-op implementations
restartConnection: async () => {},
readResource: async () => ({ contents: [] }),
callTool: async () => ({ content: [] }),
toggleServerDisabled: async () => {},
toggleToolAlwaysAllow: async () => {},
isConnecting: false,
connections: []
} as unknown as McpHub)
describe('ARCHITECT_PROMPT', () => {
let mockMcpHub: McpHub
beforeEach(() => {
jest.clearAllMocks()
})
afterEach(async () => {
// Clean up any McpHub instances
if (mockMcpHub) {
await mockMcpHub.dispose()
}
})
it('should maintain consistent architect prompt', async () => {
const prompt = await ARCHITECT_PROMPT(
'/test/path',
false, // supportsComputerUse
undefined, // mcpHub
undefined, // diffStrategy
undefined // browserViewportSize
)
expect(prompt).toMatchSnapshot()
})
it('should include browser actions when supportsComputerUse is true', async () => {
const prompt = await ARCHITECT_PROMPT(
'/test/path',
true,
undefined,
undefined,
'1280x800'
)
expect(prompt).toMatchSnapshot()
})
it('should include MCP server info when mcpHub is provided', async () => {
mockMcpHub = createMockMcpHub()
const prompt = await ARCHITECT_PROMPT(
'/test/path',
false,
mockMcpHub
)
expect(prompt).toMatchSnapshot()
})
it('should explicitly handle undefined mcpHub', async () => {
const prompt = await ARCHITECT_PROMPT(
'/test/path',
false,
undefined, // explicitly undefined mcpHub
undefined,
undefined
)
expect(prompt).toMatchSnapshot()
})
it('should handle different browser viewport sizes', async () => {
const prompt = await ARCHITECT_PROMPT(
'/test/path',
true,
undefined,
undefined,
'900x600' // different viewport size
)
expect(prompt).toMatchSnapshot()
})
it('should include diff strategy tool description', async () => {
const prompt = await ARCHITECT_PROMPT(
'/test/path',
false,
undefined,
new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase
undefined
)
expect(prompt).toMatchSnapshot()
})
afterAll(() => {
jest.restoreAllMocks()
})
})

View File

@@ -1,139 +0,0 @@
import { ASK_PROMPT } from '../ask'
import { McpHub } from '../../../services/mcp/McpHub'
import { SearchReplaceDiffStrategy } from '../../../core/diff/strategies/search-replace'
import fs from 'fs/promises'
import os from 'os'
// Import path utils to get access to toPosix string extension
import '../../../utils/path'
// Mock environment-specific values for consistent tests
jest.mock('os', () => ({
...jest.requireActual('os'),
homedir: () => '/home/user'
}))
jest.mock('default-shell', () => '/bin/bash')
jest.mock('os-name', () => () => 'Linux')
// Mock fs.readFile to return empty mcpServers config
jest.mock('fs/promises', () => ({
...jest.requireActual('fs/promises'),
readFile: jest.fn().mockImplementation(async (path: string) => {
if (path.endsWith('mcpSettings.json')) {
return '{"mcpServers": {}}'
}
if (path.endsWith('.clinerules')) {
return '# Test Rules\n1. First rule\n2. Second rule'
}
return ''
}),
writeFile: jest.fn().mockResolvedValue(undefined)
}))
// Instead of extending McpHub, create a mock that implements just what we need
const createMockMcpHub = (): McpHub => ({
getServers: () => [],
getMcpServersPath: async () => '/mock/mcp/path',
getMcpSettingsFilePath: async () => '/mock/settings/path',
dispose: async () => {},
// Add other required public methods with no-op implementations
restartConnection: async () => {},
readResource: async () => ({ contents: [] }),
callTool: async () => ({ content: [] }),
toggleServerDisabled: async () => {},
toggleToolAlwaysAllow: async () => {},
isConnecting: false,
connections: []
} as unknown as McpHub)
describe('ASK_PROMPT', () => {
let mockMcpHub: McpHub
beforeEach(() => {
jest.clearAllMocks()
})
afterEach(async () => {
// Clean up any McpHub instances
if (mockMcpHub) {
await mockMcpHub.dispose()
}
})
it('should maintain consistent ask prompt', async () => {
const prompt = await ASK_PROMPT(
'/test/path',
false, // supportsComputerUse
undefined, // mcpHub
undefined, // diffStrategy
undefined // browserViewportSize
)
expect(prompt).toMatchSnapshot()
})
it('should include browser actions when supportsComputerUse is true', async () => {
const prompt = await ASK_PROMPT(
'/test/path',
true,
undefined,
undefined,
'1280x800'
)
expect(prompt).toMatchSnapshot()
})
it('should include MCP server info when mcpHub is provided', async () => {
mockMcpHub = createMockMcpHub()
const prompt = await ASK_PROMPT(
'/test/path',
false,
mockMcpHub
)
expect(prompt).toMatchSnapshot()
})
it('should explicitly handle undefined mcpHub', async () => {
const prompt = await ASK_PROMPT(
'/test/path',
false,
undefined, // explicitly undefined mcpHub
undefined,
undefined
)
expect(prompt).toMatchSnapshot()
})
it('should handle different browser viewport sizes', async () => {
const prompt = await ASK_PROMPT(
'/test/path',
true,
undefined,
undefined,
'900x600' // different viewport size
)
expect(prompt).toMatchSnapshot()
})
it('should include diff strategy tool description', async () => {
const prompt = await ASK_PROMPT(
'/test/path',
false,
undefined,
new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase
undefined
)
expect(prompt).toMatchSnapshot()
})
afterAll(() => {
jest.restoreAllMocks()
})
})

View File

@@ -1,67 +1,68 @@
import { SYSTEM_PROMPT, addCustomInstructions } from '../system' import { SYSTEM_PROMPT, addCustomInstructions } from "../system"
import { McpHub } from '../../../services/mcp/McpHub' import { McpHub } from "../../../services/mcp/McpHub"
import { McpServer } from '../../../shared/mcp' import { McpServer } from "../../../shared/mcp"
import { ClineProvider } from '../../../core/webview/ClineProvider' import { ClineProvider } from "../../../core/webview/ClineProvider"
import { SearchReplaceDiffStrategy } from '../../../core/diff/strategies/search-replace' import { SearchReplaceDiffStrategy } from "../../../core/diff/strategies/search-replace"
import fs from 'fs/promises' import fs from "fs/promises"
import os from 'os' import os from "os"
import { codeMode, askMode, architectMode } from '../modes' import { defaultModeSlug, modes } from "../../../shared/modes"
// Import path utils to get access to toPosix string extension // Import path utils to get access to toPosix string extension
import '../../../utils/path' import "../../../utils/path"
// Mock environment-specific values for consistent tests // Mock environment-specific values for consistent tests
jest.mock('os', () => ({ jest.mock("os", () => ({
...jest.requireActual('os'), ...jest.requireActual("os"),
homedir: () => '/home/user' homedir: () => "/home/user",
})) }))
jest.mock('default-shell', () => '/bin/bash') jest.mock("default-shell", () => "/bin/bash")
jest.mock('os-name', () => () => 'Linux') jest.mock("os-name", () => () => "Linux")
// Mock fs.readFile to return empty mcpServers config and mock rules files // Mock fs.readFile to return empty mcpServers config and mock rules files
jest.mock('fs/promises', () => ({ jest.mock("fs/promises", () => ({
...jest.requireActual('fs/promises'), ...jest.requireActual("fs/promises"),
readFile: jest.fn().mockImplementation(async (path: string) => { readFile: jest.fn().mockImplementation(async (path: string) => {
if (path.endsWith('mcpSettings.json')) { if (path.endsWith("mcpSettings.json")) {
return '{"mcpServers": {}}' return '{"mcpServers": {}}'
} }
if (path.endsWith('.clinerules-code')) { if (path.endsWith(".clinerules-code")) {
return '# Code Mode Rules\n1. Code specific rule' return "# Code Mode Rules\n1. Code specific rule"
} }
if (path.endsWith('.clinerules-ask')) { if (path.endsWith(".clinerules-ask")) {
return '# Ask Mode Rules\n1. Ask specific rule' return "# Ask Mode Rules\n1. Ask specific rule"
} }
if (path.endsWith('.clinerules-architect')) { if (path.endsWith(".clinerules-architect")) {
return '# Architect Mode Rules\n1. Architect specific rule' return "# Architect Mode Rules\n1. Architect specific rule"
} }
if (path.endsWith('.clinerules')) { if (path.endsWith(".clinerules")) {
return '# Test Rules\n1. First rule\n2. Second rule' return "# Test Rules\n1. First rule\n2. Second rule"
} }
return '' return ""
}), }),
writeFile: jest.fn().mockResolvedValue(undefined) writeFile: jest.fn().mockResolvedValue(undefined),
})) }))
// Create a minimal mock of ClineProvider // Create a minimal mock of ClineProvider
const mockProvider = { const mockProvider = {
ensureMcpServersDirectoryExists: async () => '/mock/mcp/path', ensureMcpServersDirectoryExists: async () => "/mock/mcp/path",
ensureSettingsDirectoryExists: async () => '/mock/settings/path', ensureSettingsDirectoryExists: async () => "/mock/settings/path",
postMessageToWebview: async () => {}, postMessageToWebview: async () => {},
context: { context: {
extension: { extension: {
packageJSON: { packageJSON: {
version: '1.0.0' version: "1.0.0",
} },
} },
} },
} as unknown as ClineProvider } as unknown as ClineProvider
// Instead of extending McpHub, create a mock that implements just what we need // Instead of extending McpHub, create a mock that implements just what we need
const createMockMcpHub = (): McpHub => ({ const createMockMcpHub = (): McpHub =>
({
getServers: () => [], getServers: () => [],
getMcpServersPath: async () => '/mock/mcp/path', getMcpServersPath: async () => "/mock/mcp/path",
getMcpSettingsFilePath: async () => '/mock/settings/path', getMcpSettingsFilePath: async () => "/mock/settings/path",
dispose: async () => {}, dispose: async () => {},
// Add other required public methods with no-op implementations // Add other required public methods with no-op implementations
restartConnection: async () => {}, restartConnection: async () => {},
@@ -70,10 +71,10 @@ const createMockMcpHub = (): McpHub => ({
toggleServerDisabled: async () => {}, toggleServerDisabled: async () => {},
toggleToolAlwaysAllow: async () => {}, toggleToolAlwaysAllow: async () => {},
isConnecting: false, isConnecting: false,
connections: [] connections: [],
} as unknown as McpHub) }) as unknown as McpHub
describe('SYSTEM_PROMPT', () => { describe("SYSTEM_PROMPT", () => {
let mockMcpHub: McpHub let mockMcpHub: McpHub
beforeEach(() => { beforeEach(() => {
@@ -87,73 +88,63 @@ describe('SYSTEM_PROMPT', () => {
} }
}) })
it('should maintain consistent system prompt', async () => { it("should maintain consistent system prompt", async () => {
const prompt = await SYSTEM_PROMPT( const prompt = await SYSTEM_PROMPT(
'/test/path', "/test/path",
false, // supportsComputerUse false, // supportsComputerUse
undefined, // mcpHub undefined, // mcpHub
undefined, // diffStrategy undefined, // diffStrategy
undefined // browserViewportSize undefined, // browserViewportSize
) )
expect(prompt).toMatchSnapshot() expect(prompt).toMatchSnapshot()
}) })
it('should include browser actions when supportsComputerUse is true', async () => { it("should include browser actions when supportsComputerUse is true", async () => {
const prompt = await SYSTEM_PROMPT( const prompt = await SYSTEM_PROMPT("/test/path", true, undefined, undefined, "1280x800")
'/test/path',
true,
undefined,
undefined,
'1280x800'
)
expect(prompt).toMatchSnapshot() expect(prompt).toMatchSnapshot()
}) })
it('should include MCP server info when mcpHub is provided', async () => { it("should include MCP server info when mcpHub is provided", async () => {
mockMcpHub = createMockMcpHub() mockMcpHub = createMockMcpHub()
const prompt = await SYSTEM_PROMPT( const prompt = await SYSTEM_PROMPT("/test/path", false, mockMcpHub)
'/test/path',
false,
mockMcpHub
)
expect(prompt).toMatchSnapshot() expect(prompt).toMatchSnapshot()
}) })
it('should explicitly handle undefined mcpHub', async () => { it("should explicitly handle undefined mcpHub", async () => {
const prompt = await SYSTEM_PROMPT( const prompt = await SYSTEM_PROMPT(
'/test/path', "/test/path",
false, false,
undefined, // explicitly undefined mcpHub undefined, // explicitly undefined mcpHub
undefined, undefined,
undefined undefined,
) )
expect(prompt).toMatchSnapshot() expect(prompt).toMatchSnapshot()
}) })
it('should handle different browser viewport sizes', async () => { it("should handle different browser viewport sizes", async () => {
const prompt = await SYSTEM_PROMPT( const prompt = await SYSTEM_PROMPT(
'/test/path', "/test/path",
true, true,
undefined, undefined,
undefined, undefined,
'900x600' // different viewport size "900x600", // different viewport size
) )
expect(prompt).toMatchSnapshot() expect(prompt).toMatchSnapshot()
}) })
it('should include diff strategy tool description', async () => { it("should include diff strategy tool description", async () => {
const prompt = await SYSTEM_PROMPT( const prompt = await SYSTEM_PROMPT(
'/test/path', "/test/path",
false, false,
undefined, undefined,
new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase
undefined undefined,
) )
expect(prompt).toMatchSnapshot() expect(prompt).toMatchSnapshot()
@@ -164,151 +155,197 @@ describe('SYSTEM_PROMPT', () => {
}) })
}) })
describe('addCustomInstructions', () => { describe("addCustomInstructions", () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() jest.clearAllMocks()
}) })
it('should prioritize mode-specific rules for code mode', async () => { it("should generate correct prompt for architect mode", async () => {
const instructions = await addCustomInstructions( const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "architect")
{},
'/test/path', expect(prompt).toMatchSnapshot()
codeMode })
)
it("should generate correct prompt for ask mode", async () => {
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "ask")
expect(prompt).toMatchSnapshot()
})
it("should prioritize mode-specific rules for code mode", async () => {
const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug)
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should prioritize mode-specific rules for ask mode', async () => { it("should prioritize mode-specific rules for ask mode", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions({}, "/test/path", modes[2].slug)
{},
'/test/path',
askMode
)
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should prioritize mode-specific rules for architect mode', async () => { it("should prioritize mode-specific rules for architect mode", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions({}, "/test/path", modes[1].slug)
{},
'/test/path',
architectMode
)
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should fall back to generic rules when mode-specific rules not found', async () => { it("should prioritize mode-specific rules for test engineer mode", async () => {
// Mock readFile to include test engineer rules
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
if (path.endsWith(".clinerules-test")) {
return "# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code"
}
if (path.endsWith(".clinerules")) {
return "# Test Rules\n1. First rule\n2. Second rule"
}
return ""
})
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
const instructions = await addCustomInstructions({}, "/test/path", "test")
expect(instructions).toMatchSnapshot()
})
it("should prioritize mode-specific rules for code reviewer mode", async () => {
// Mock readFile to include code reviewer rules
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
if (path.endsWith(".clinerules-review")) {
return "# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices"
}
if (path.endsWith(".clinerules")) {
return "# Test Rules\n1. First rule\n2. Second rule"
}
return ""
})
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
const instructions = await addCustomInstructions({}, "/test/path", "review")
expect(instructions).toMatchSnapshot()
})
it("should generate correct prompt for test engineer mode", async () => {
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "test")
// Verify test engineer role requirements
expect(prompt).toContain("must ask the user to confirm before making ANY changes to non-test code")
expect(prompt).toContain("ask the user to confirm your test plan")
expect(prompt).toMatchSnapshot()
})
it("should generate correct prompt for code reviewer mode", async () => {
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "review")
// Verify code reviewer role constraints
expect(prompt).toContain("providing detailed, actionable feedback")
expect(prompt).toContain("maintain a read-only approach")
expect(prompt).toMatchSnapshot()
})
it("should fall back to generic rules when mode-specific rules not found", async () => {
// Mock readFile to return ENOENT for mode-specific file // Mock readFile to return ENOENT for mode-specific file
const mockReadFile = jest.fn().mockImplementation(async (path: string) => { const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
if (path.endsWith('.clinerules-code')) { if (
const error = new Error('ENOENT') as NodeJS.ErrnoException path.endsWith(".clinerules-code") ||
error.code = 'ENOENT' path.endsWith(".clinerules-test") ||
path.endsWith(".clinerules-review")
) {
const error = new Error("ENOENT") as NodeJS.ErrnoException
error.code = "ENOENT"
throw error throw error
} }
if (path.endsWith('.clinerules')) { if (path.endsWith(".clinerules")) {
return '# Test Rules\n1. First rule\n2. Second rule' return "# Test Rules\n1. First rule\n2. Second rule"
} }
return '' return ""
}) })
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile) jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug)
expect(instructions).toMatchSnapshot()
})
it("should include preferred language when provided", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions(
{}, { preferredLanguage: "Spanish" },
'/test/path', "/test/path",
codeMode defaultModeSlug,
) )
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should include preferred language when provided', async () => { it("should include custom instructions when provided", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions(
{ preferredLanguage: 'Spanish' }, { customInstructions: "Custom test instructions" },
'/test/path', "/test/path",
codeMode
) )
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should include custom instructions when provided', async () => { it("should combine all custom instructions", async () => {
const instructions = await addCustomInstructions(
{ customInstructions: 'Custom test instructions' },
'/test/path'
)
expect(instructions).toMatchSnapshot()
})
it('should combine all custom instructions', async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions(
{ {
customInstructions: 'Custom test instructions', customInstructions: "Custom test instructions",
preferredLanguage: 'French' preferredLanguage: "French",
}, },
'/test/path', "/test/path",
codeMode defaultModeSlug,
) )
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should handle undefined mode-specific instructions', async () => { it("should handle undefined mode-specific instructions", async () => {
const instructions = await addCustomInstructions({}, "/test/path")
expect(instructions).toMatchSnapshot()
})
it("should trim mode-specific instructions", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions(
{}, { customInstructions: " Custom mode instructions " },
'/test/path' "/test/path",
) )
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should trim mode-specific instructions', async () => { it("should handle empty mode-specific instructions", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions({ customInstructions: "" }, "/test/path")
{ customInstructions: ' Custom mode instructions ' },
'/test/path'
)
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should handle empty mode-specific instructions', async () => { it("should combine global and mode-specific instructions", async () => {
const instructions = await addCustomInstructions(
{ customInstructions: '' },
'/test/path'
)
expect(instructions).toMatchSnapshot()
})
it('should combine global and mode-specific instructions', async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions(
{ {
customInstructions: 'Global instructions', customInstructions: "Global instructions",
customPrompts: { customPrompts: {
code: { customInstructions: 'Mode-specific instructions' } code: { customInstructions: "Mode-specific instructions" },
}
}, },
'/test/path', },
codeMode "/test/path",
defaultModeSlug,
) )
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()
}) })
it('should prioritize mode-specific instructions after global ones', async () => { it("should prioritize mode-specific instructions after global ones", async () => {
const instructions = await addCustomInstructions( const instructions = await addCustomInstructions(
{ {
customInstructions: 'First instruction', customInstructions: "First instruction",
customPrompts: { customPrompts: {
code: { customInstructions: 'Second instruction' } code: { customInstructions: "Second instruction" },
}
}, },
'/test/path', },
codeMode "/test/path",
defaultModeSlug,
) )
const instructionParts = instructions.split('\n\n') const instructionParts = instructions.split("\n\n")
const globalIndex = instructionParts.findIndex(part => part.includes('First instruction')) const globalIndex = instructionParts.findIndex((part) => part.includes("First instruction"))
const modeSpecificIndex = instructionParts.findIndex(part => part.includes('Second instruction')) const modeSpecificIndex = instructionParts.findIndex((part) => part.includes("Second instruction"))
expect(globalIndex).toBeLessThan(modeSpecificIndex) expect(globalIndex).toBeLessThan(modeSpecificIndex)
expect(instructions).toMatchSnapshot() expect(instructions).toMatchSnapshot()

View File

@@ -1,40 +0,0 @@
import { architectMode, defaultPrompts, PromptComponent } from "../../shared/modes"
import { getToolDescriptionsForMode } from "./tools"
import {
getRulesSection,
getSystemInfoSection,
getObjectiveSection,
getSharedToolUseSection,
getMcpServersSection,
getToolUseGuidelinesSection,
getCapabilitiesSection
} from "./sections"
import { DiffStrategy } from "../diff/DiffStrategy"
import { McpHub } from "../../services/mcp/McpHub"
export const mode = architectMode
export const ARCHITECT_PROMPT = async (
cwd: string,
supportsComputerUse: boolean,
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
customPrompt?: PromptComponent,
) => `${customPrompt?.roleDefinition || defaultPrompts[architectMode].roleDefinition}
${getSharedToolUseSection()}
${getToolDescriptionsForMode(mode, cwd, supportsComputerUse, diffStrategy, browserViewportSize, mcpHub)}
${getToolUseGuidelinesSection()}
${await getMcpServersSection(mcpHub, diffStrategy)}
${getCapabilitiesSection(cwd, supportsComputerUse, mcpHub, diffStrategy)}
${getRulesSection(cwd, supportsComputerUse, diffStrategy)}
${getSystemInfoSection(cwd)}
${getObjectiveSection()}`

View File

@@ -1,40 +0,0 @@
import { Mode, askMode, defaultPrompts, PromptComponent } from "../../shared/modes"
import { getToolDescriptionsForMode } from "./tools"
import {
getRulesSection,
getSystemInfoSection,
getObjectiveSection,
getSharedToolUseSection,
getMcpServersSection,
getToolUseGuidelinesSection,
getCapabilitiesSection
} from "./sections"
import { DiffStrategy } from "../diff/DiffStrategy"
import { McpHub } from "../../services/mcp/McpHub"
export const mode = askMode
export const ASK_PROMPT = async (
cwd: string,
supportsComputerUse: boolean,
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
customPrompt?: PromptComponent,
) => `${customPrompt?.roleDefinition || defaultPrompts[askMode].roleDefinition}
${getSharedToolUseSection()}
${getToolDescriptionsForMode(mode, cwd, supportsComputerUse, diffStrategy, browserViewportSize, mcpHub)}
${getToolUseGuidelinesSection()}
${await getMcpServersSection(mcpHub, diffStrategy)}
${getCapabilitiesSection(cwd, supportsComputerUse, mcpHub, diffStrategy)}
${getRulesSection(cwd, supportsComputerUse, diffStrategy)}
${getSystemInfoSection(cwd)}
${getObjectiveSection()}`

View File

@@ -1,40 +0,0 @@
import { Mode, codeMode, defaultPrompts, PromptComponent } from "../../shared/modes"
import { getToolDescriptionsForMode } from "./tools"
import {
getRulesSection,
getSystemInfoSection,
getObjectiveSection,
getSharedToolUseSection,
getMcpServersSection,
getToolUseGuidelinesSection,
getCapabilitiesSection
} from "./sections"
import { DiffStrategy } from "../diff/DiffStrategy"
import { McpHub } from "../../services/mcp/McpHub"
export const mode: Mode = codeMode
export const CODE_PROMPT = async (
cwd: string,
supportsComputerUse: boolean,
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
customPrompt?: PromptComponent,
) => `${customPrompt?.roleDefinition || defaultPrompts[codeMode].roleDefinition}
${getSharedToolUseSection()}
${getToolDescriptionsForMode(mode, cwd, supportsComputerUse, diffStrategy, browserViewportSize, mcpHub)}
${getToolUseGuidelinesSection()}
${await getMcpServersSection(mcpHub, diffStrategy)}
${getCapabilitiesSection(cwd, supportsComputerUse, mcpHub, diffStrategy)}
${getRulesSection(cwd, supportsComputerUse, diffStrategy)}
${getSystemInfoSection(cwd)}
${getObjectiveSection()}`

View File

@@ -1,5 +0,0 @@
export const codeMode = 'code' as const;
export const architectMode = 'architect' as const;
export const askMode = 'ask' as const;
export type Mode = typeof codeMode | typeof architectMode | typeof askMode;

View File

@@ -13,7 +13,7 @@ CAPABILITIES
- You have access to tools that let you execute CLI commands on the user's computer, list files, view source code definitions, regex search${ - You have access to tools that let you execute CLI commands on the user's computer, list files, view source code definitions, regex search${
supportsComputerUse ? ", use the browser" : "" supportsComputerUse ? ", use the browser" : ""
}, read and write files, and ask follow-up questions. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, performing system operations, and much more. }, read and write files, and ask follow-up questions. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, performing system operations, and much more.
- When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('${cwd}') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop. - When the user initially gives you a task, a recursive list of all filepaths in the current working directory ('${cwd}') will be included in environment_details. This provides an overview of the project's file structure, offering key insights into the project from directory/file names (how developers conceptualize and organize their code) and file extensions (the language used). This can also guide decision-making on which files to explore further. If you need to further explore directories such as outside the current working directory, you can use the list_files tool. If you pass 'true' for the recursive parameter, it will list files recursively. Otherwise, it will list files at the top level, which is better suited for generic directories where you don't necessarily need the nested structure, like the Desktop.
- You can use search_files to perform regex searches across files in a specified directory, outputting context-rich results that include surrounding lines. This is particularly useful for understanding code patterns, finding specific implementations, or identifying areas that need refactoring. - You can use search_files to perform regex searches across files in a specified directory, outputting context-rich results that include surrounding lines. This is particularly useful for understanding code patterns, finding specific implementations, or identifying areas that need refactoring.
- You can use the list_code_definition_names tool to get an overview of source code definitions for all files at the top level of a specified directory. This can be particularly useful when you need to understand the broader context and relationships between certain parts of the code. You may need to call this tool multiple times to understand various parts of the codebase related to the task. - You can use the list_code_definition_names tool to get an overview of source code definitions for all files at the top level of a specified directory. This can be particularly useful when you need to understand the broader context and relationships between certain parts of the code. You may need to call this tool multiple times to understand various parts of the codebase related to the task.
@@ -22,7 +22,11 @@ CAPABILITIES
supportsComputerUse supportsComputerUse
? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser." ? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser."
: "" : ""
}${mcpHub ? ` }${
mcpHub
? `
- You have access to MCP servers that may provide additional tools and resources. Each server may provide different capabilities that you can use to accomplish tasks more effectively. - You have access to MCP servers that may provide additional tools and resources. Each server may provide different capabilities that you can use to accomplish tasks more effectively.
` : ''}` `
: ""
}`
} }

View File

@@ -1,19 +1,19 @@
import fs from 'fs/promises' import fs from "fs/promises"
import path from 'path' import path from "path"
export async function loadRuleFiles(cwd: string): Promise<string> { export async function loadRuleFiles(cwd: string): Promise<string> {
const ruleFiles = ['.clinerules', '.cursorrules', '.windsurfrules'] const ruleFiles = [".clinerules", ".cursorrules", ".windsurfrules"]
let combinedRules = '' let combinedRules = ""
for (const file of ruleFiles) { for (const file of ruleFiles) {
try { try {
const content = await fs.readFile(path.join(cwd, file), 'utf-8') const content = await fs.readFile(path.join(cwd, file), "utf-8")
if (content.trim()) { if (content.trim()) {
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n` combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
} }
} catch (err) { } catch (err) {
// Silently skip if file doesn't exist // Silently skip if file doesn't exist
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') { if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
throw err throw err
} }
} }
@@ -22,7 +22,11 @@ export async function loadRuleFiles(cwd: string): Promise<string> {
return combinedRules return combinedRules
} }
export async function addCustomInstructions(customInstructions: string, cwd: string, preferredLanguage?: string): Promise<string> { export async function addCustomInstructions(
customInstructions: string,
cwd: string,
preferredLanguage?: string,
): Promise<string> {
const ruleFileContent = await loadRuleFiles(cwd) const ruleFileContent = await loadRuleFiles(cwd)
const allInstructions = [] const allInstructions = []
@@ -38,9 +42,10 @@ export async function addCustomInstructions(customInstructions: string, cwd: str
allInstructions.push(ruleFileContent.trim()) allInstructions.push(ruleFileContent.trim())
} }
const joinedInstructions = allInstructions.join('\n\n') const joinedInstructions = allInstructions.join("\n\n")
return joinedInstructions ? ` return joinedInstructions
? `
==== ====
USER'S CUSTOM INSTRUCTIONS USER'S CUSTOM INSTRUCTIONS

View File

@@ -1,8 +1,8 @@
export { getRulesSection } from './rules' export { getRulesSection } from "./rules"
export { getSystemInfoSection } from './system-info' export { getSystemInfoSection } from "./system-info"
export { getObjectiveSection } from './objective' export { getObjectiveSection } from "./objective"
export { addCustomInstructions } from './custom-instructions' export { addCustomInstructions } from "./custom-instructions"
export { getSharedToolUseSection } from './tool-use' export { getSharedToolUseSection } from "./tool-use"
export { getMcpServersSection } from './mcp-servers' export { getMcpServersSection } from "./mcp-servers"
export { getToolUseGuidelinesSection } from './tool-use-guidelines' export { getToolUseGuidelinesSection } from "./tool-use-guidelines"
export { getCapabilitiesSection } from './capabilities' export { getCapabilitiesSection } from "./capabilities"

View File

@@ -3,10 +3,11 @@ import { McpHub } from "../../../services/mcp/McpHub"
export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffStrategy): Promise<string> { export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffStrategy): Promise<string> {
if (!mcpHub) { if (!mcpHub) {
return ''; return ""
} }
const connectedServers = mcpHub.getServers().length > 0 const connectedServers =
mcpHub.getServers().length > 0
? `${mcpHub ? `${mcpHub
.getServers() .getServers()
.filter((server) => server.status === "connected") .filter((server) => server.status === "connected")
@@ -40,7 +41,7 @@ export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffS
) )
}) })
.join("\n\n")}` .join("\n\n")}`
: "(No MCP servers currently connected)"; : "(No MCP servers currently connected)"
return `MCP SERVERS return `MCP SERVERS
@@ -401,7 +402,7 @@ The user may ask to add tools or resources that may make sense to add to an exis
.getServers() .getServers()
.map((server) => server.name) .map((server) => server.name)
.join(", ") || "(None running currently)" .join(", ") || "(None running currently)"
}, e.g. if it would use the same API. This would be possible if you can locate the MCP server repository on the user's system by looking at the server arguments for a filepath. You might then use list_files and read_file to explore the files in the repository, and use write_to_file${diffStrategy ? " or apply_diff" : ""} to make changes to the files. }, e.g. if it would use the same API. This would be possible if you can locate the MCP server repository on the user's system by looking at the server arguments for a filepath. You might then use list_files and read_file to explore the files in the repository, and use write_to_file${diffStrategy ? " or apply_diff" : ""} to make changes to the files.
However some MCP servers may be running from installed packages rather than a local repository, in which case it may make more sense to create a new MCP server. However some MCP servers may be running from installed packages rather than a local repository, in which case it may make more sense to create a new MCP server.

View File

@@ -1,10 +1,6 @@
import { DiffStrategy } from "../../diff/DiffStrategy" import { DiffStrategy } from "../../diff/DiffStrategy"
export function getRulesSection( export function getRulesSection(cwd: string, supportsComputerUse: boolean, diffStrategy?: DiffStrategy): string {
cwd: string,
supportsComputerUse: boolean,
diffStrategy?: DiffStrategy
): string {
return `==== return `====
RULES RULES
@@ -26,7 +22,7 @@ ${diffStrategy ? "- You should use apply_diff instead of write_to_file when maki
supportsComputerUse supportsComputerUse
? '\n- The user may ask generic non-development tasks, such as "what\'s the latest news" or "look up the weather in San Diego", in which case you might use the browser_action tool to complete the task if it makes sense to do so, rather than trying to create a website or using curl to answer the question. However, if an available MCP server tool or resource can be used instead, you should prefer to use it over browser_action.' ? '\n- The user may ask generic non-development tasks, such as "what\'s the latest news" or "look up the weather in San Diego", in which case you might use the browser_action tool to complete the task if it makes sense to do so, rather than trying to create a website or using curl to answer the question. However, if an available MCP server tool or resource can be used instead, you should prefer to use it over browser_action.'
: "" : ""
} }
- NEVER end attempt_completion result with a question or request to engage in further conversation! Formulate the end of your result in a way that is final and does not require further input from the user. - NEVER end attempt_completion result with a question or request to engage in further conversation! Formulate the end of your result in a way that is final and does not require further input from the user.
- You are STRICTLY FORBIDDEN from starting your messages with "Great", "Certainly", "Okay", "Sure". You should NOT be conversational in your responses, but rather direct and to the point. For example you should NOT say "Great, I've updated the CSS" but instead something like "I've updated the CSS". It is important you be clear and technical in your messages. - You are STRICTLY FORBIDDEN from starting your messages with "Great", "Certainly", "Okay", "Sure". You should NOT be conversational in your responses, but rather direct and to the point. For example you should NOT say "Great, I've updated the CSS" but instead something like "I've updated the CSS". It is important you be clear and technical in your messages.
- When presented with images, utilize your vision capabilities to thoroughly examine them and extract meaningful information. Incorporate these insights into your thought process as you accomplish the user's task. - When presented with images, utilize your vision capabilities to thoroughly examine them and extract meaningful information. Incorporate these insights into your thought process as you accomplish the user's task.
@@ -38,5 +34,5 @@ ${diffStrategy ? "- You should use apply_diff instead of write_to_file when maki
supportsComputerUse supportsComputerUse
? " Then if you want to test your work, you might use browser_action to launch the site, wait for the user's response confirming the site was launched along with a screenshot, then perhaps e.g., click a button to test functionality if needed, wait for the user's response confirming the button was clicked along with a screenshot of the new state, before finally closing the browser." ? " Then if you want to test your work, you might use browser_action to launch the site, wait for the user's response confirming the site was launched along with a screenshot, then perhaps e.g., click a button to test functionality if needed, wait for the user's response confirming the button was clicked along with a screenshot of the new state, before finally closing the browser."
: "" : ""
}` }`
} }

View File

@@ -1,41 +1,47 @@
import { Mode, modes, CustomPrompts, PromptComponent, getRoleDefinition, defaultModeSlug } from "../../shared/modes"
import { DiffStrategy } from "../diff/DiffStrategy" import { DiffStrategy } from "../diff/DiffStrategy"
import { McpHub } from "../../services/mcp/McpHub" import { McpHub } from "../../services/mcp/McpHub"
import { CODE_PROMPT } from "./code" import { getToolDescriptionsForMode } from "./tools"
import { ARCHITECT_PROMPT } from "./architect" import {
import { ASK_PROMPT } from "./ask" getRulesSection,
import { Mode, codeMode, architectMode, askMode } from "./modes" getSystemInfoSection,
import { CustomPrompts } from "../../shared/modes" getObjectiveSection,
import fs from 'fs/promises' getSharedToolUseSection,
import path from 'path' getMcpServersSection,
getToolUseGuidelinesSection,
getCapabilitiesSection,
} from "./sections"
import fs from "fs/promises"
import path from "path"
async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> { async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> {
let combinedRules = '' let combinedRules = ""
// First try mode-specific rules // First try mode-specific rules
const modeSpecificFile = `.clinerules-${mode}` const modeSpecificFile = `.clinerules-${mode}`
try { try {
const content = await fs.readFile(path.join(cwd, modeSpecificFile), 'utf-8') const content = await fs.readFile(path.join(cwd, modeSpecificFile), "utf-8")
if (content.trim()) { if (content.trim()) {
combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n` combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n`
} }
} catch (err) { } catch (err) {
// Silently skip if file doesn't exist // Silently skip if file doesn't exist
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') { if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
throw err throw err
} }
} }
// Then try generic rules files // Then try generic rules files
const genericRuleFiles = ['.clinerules'] const genericRuleFiles = [".clinerules"]
for (const file of genericRuleFiles) { for (const file of genericRuleFiles) {
try { try {
const content = await fs.readFile(path.join(cwd, file), 'utf-8') const content = await fs.readFile(path.join(cwd, file), "utf-8")
if (content.trim()) { if (content.trim()) {
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n` combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
} }
} catch (err) { } catch (err) {
// Silently skip if file doesn't exist // Silently skip if file doesn't exist
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') { if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
throw err throw err
} }
} }
@@ -45,16 +51,12 @@ async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> {
} }
interface State { interface State {
customInstructions?: string; customInstructions?: string
customPrompts?: CustomPrompts; customPrompts?: CustomPrompts
preferredLanguage?: string; preferredLanguage?: string
} }
export async function addCustomInstructions( export async function addCustomInstructions(state: State, cwd: string, mode: Mode = defaultModeSlug): Promise<string> {
state: State,
cwd: string,
mode: Mode = codeMode
): Promise<string> {
const ruleFileContent = await loadRuleFiles(cwd, mode) const ruleFileContent = await loadRuleFiles(cwd, mode)
const allInstructions = [] const allInstructions = []
@@ -66,17 +68,19 @@ export async function addCustomInstructions(
allInstructions.push(state.customInstructions.trim()) allInstructions.push(state.customInstructions.trim())
} }
if (state.customPrompts?.[mode]?.customInstructions?.trim()) { const customPrompt = state.customPrompts?.[mode]
allInstructions.push(state.customPrompts[mode].customInstructions.trim()) if (typeof customPrompt === "object" && customPrompt?.customInstructions?.trim()) {
allInstructions.push(customPrompt.customInstructions.trim())
} }
if (ruleFileContent && ruleFileContent.trim()) { if (ruleFileContent && ruleFileContent.trim()) {
allInstructions.push(ruleFileContent.trim()) allInstructions.push(ruleFileContent.trim())
} }
const joinedInstructions = allInstructions.join('\n\n') const joinedInstructions = allInstructions.join("\n\n")
return joinedInstructions ? ` return joinedInstructions
? `
==== ====
USER'S CUSTOM INSTRUCTIONS USER'S CUSTOM INSTRUCTIONS
@@ -87,23 +91,63 @@ ${joinedInstructions}`
: "" : ""
} }
async function generatePrompt(
cwd: string,
supportsComputerUse: boolean,
mode: Mode,
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
promptComponent?: PromptComponent,
): Promise<string> {
const basePrompt = `${promptComponent?.roleDefinition || getRoleDefinition(mode)}
${getSharedToolUseSection()}
${getToolDescriptionsForMode(mode, cwd, supportsComputerUse, diffStrategy, browserViewportSize, mcpHub)}
${getToolUseGuidelinesSection()}
${await getMcpServersSection(mcpHub, diffStrategy)}
${getCapabilitiesSection(cwd, supportsComputerUse, mcpHub, diffStrategy)}
${getRulesSection(cwd, supportsComputerUse, diffStrategy)}
${getSystemInfoSection(cwd)}
${getObjectiveSection()}`
return basePrompt
}
export const SYSTEM_PROMPT = async ( export const SYSTEM_PROMPT = async (
cwd: string, cwd: string,
supportsComputerUse: boolean, supportsComputerUse: boolean,
mcpHub?: McpHub, mcpHub?: McpHub,
diffStrategy?: DiffStrategy, diffStrategy?: DiffStrategy,
browserViewportSize?: string, browserViewportSize?: string,
mode: Mode = codeMode, mode: Mode = defaultModeSlug,
customPrompts?: CustomPrompts, customPrompts?: CustomPrompts,
) => { ) => {
switch (mode) { const getPromptComponent = (value: unknown) => {
case architectMode: if (typeof value === "object" && value !== null) {
return ARCHITECT_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize, customPrompts?.architect) return value as PromptComponent
case askMode: }
return ASK_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize, customPrompts?.ask) return undefined
default:
return CODE_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize, customPrompts?.code)
} }
}
export { codeMode, architectMode, askMode } // Use default mode if not found
const currentMode = modes.find((m) => m.slug === mode) || modes[0]
const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug])
return generatePrompt(
cwd,
supportsComputerUse,
currentMode.slug,
mcpHub,
diffStrategy,
browserViewportSize,
promptComponent,
)
}

View File

@@ -1,4 +1,9 @@
export function getAccessMcpResourceDescription(): string { import { ToolArgs } from "./types"
export function getAccessMcpResourceDescription(args: ToolArgs): string | undefined {
if (!args.mcpHub) {
return undefined
}
return `## access_mcp_resource return `## access_mcp_resource
Description: Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information. Description: Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.
Parameters: Parameters:

View File

@@ -1,9 +1,14 @@
export function getBrowserActionDescription(cwd: string, browserViewportSize: string = "900x600"): string { import { ToolArgs } from "./types"
export function getBrowserActionDescription(args: ToolArgs): string | undefined {
if (!args.supportsComputerUse) {
return undefined
}
return `## browser_action return `## browser_action
Description: Request to interact with a Puppeteer-controlled browser. Every action, except \`close\`, will be responded to with a screenshot of the browser's current state, along with any new console logs. You may only perform one browser action per message, and wait for the user's response including a screenshot and logs to determine the next action. Description: Request to interact with a Puppeteer-controlled browser. Every action, except \`close\`, will be responded to with a screenshot of the browser's current state, along with any new console logs. You may only perform one browser action per message, and wait for the user's response including a screenshot and logs to determine the next action.
- The sequence of actions **must always start with** launching the browser at a URL, and **must always end with** closing the browser. If you need to visit a new URL that is not possible to navigate to from the current webpage, you must first close the browser, then launch again at the new URL. - The sequence of actions **must always start with** launching the browser at a URL, and **must always end with** closing the browser. If you need to visit a new URL that is not possible to navigate to from the current webpage, you must first close the browser, then launch again at the new URL.
- While the browser is active, only the \`browser_action\` tool can be used. No other tools should be called during this time. You may proceed to use other tools only after closing the browser. For example if you run into an error and need to fix a file, you must close the browser, then use other tools to make the necessary changes, then re-launch the browser to verify the result. - While the browser is active, only the \`browser_action\` tool can be used. No other tools should be called during this time. You may proceed to use other tools only after closing the browser. For example if you run into an error and need to fix a file, you must close the browser, then use other tools to make the necessary changes, then re-launch the browser to verify the result.
- The browser window has a resolution of **${browserViewportSize}** pixels. When performing any click actions, ensure the coordinates are within this resolution range. - The browser window has a resolution of **${args.browserViewportSize}** pixels. When performing any click actions, ensure the coordinates are within this resolution range.
- Before clicking on any elements such as icons, links, or buttons, you must consult the provided screenshot of the page to determine the coordinates of the element. The click should be targeted at the **center of the element**, not on its edges. - Before clicking on any elements such as icons, links, or buttons, you must consult the provided screenshot of the page to determine the coordinates of the element. The click should be targeted at the **center of the element**, not on its edges.
Parameters: Parameters:
- action: (required) The action to perform. The available actions are: - action: (required) The action to perform. The available actions are:
@@ -21,7 +26,7 @@ Parameters:
- Example: \`<action>close</action>\` - Example: \`<action>close</action>\`
- url: (optional) Use this for providing the URL for the \`launch\` action. - url: (optional) Use this for providing the URL for the \`launch\` action.
* Example: <url>https://example.com</url> * Example: <url>https://example.com</url>
- coordinate: (optional) The X and Y coordinates for the \`click\` action. Coordinates should be within the **${browserViewportSize}** resolution. - coordinate: (optional) The X and Y coordinates for the \`click\` action. Coordinates should be within the **${args.browserViewportSize}** resolution.
* Example: <coordinate>450,300</coordinate> * Example: <coordinate>450,300</coordinate>
- text: (optional) Use this for providing the text for the \`type\` action. - text: (optional) Use this for providing the text for the \`type\` action.
* Example: <text>Hello, world!</text> * Example: <text>Hello, world!</text>

View File

@@ -1,6 +1,8 @@
export function getExecuteCommandDescription(cwd: string): string { import { ToolArgs } from "./types"
export function getExecuteCommandDescription(args: ToolArgs): string | undefined {
return `## execute_command return `## execute_command
Description: Request to execute a CLI command on the system. Use this when you need to perform system operations or run specific commands to accomplish any step in the user's task. You must tailor your command to the user's system and provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, as they are more flexible and easier to run. Commands will be executed in the current working directory: ${cwd} Description: Request to execute a CLI command on the system. Use this when you need to perform system operations or run specific commands to accomplish any step in the user's task. You must tailor your command to the user's system and provide a clear explanation of what the command does. Prefer to execute complex CLI commands over creating executable scripts, as they are more flexible and easier to run. Commands will be executed in the current working directory: ${args.cwd}
Parameters: Parameters:
- command: (required) The CLI command to execute. This should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions. - command: (required) The CLI command to execute. This should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions.
Usage: Usage:

View File

@@ -1,24 +1,34 @@
import { getExecuteCommandDescription } from './execute-command' import { getExecuteCommandDescription } from "./execute-command"
import { getReadFileDescription } from './read-file' import { getReadFileDescription } from "./read-file"
import { getWriteToFileDescription } from './write-to-file' import { getWriteToFileDescription } from "./write-to-file"
import { getSearchFilesDescription } from './search-files' import { getSearchFilesDescription } from "./search-files"
import { getListFilesDescription } from './list-files' import { getListFilesDescription } from "./list-files"
import { getListCodeDefinitionNamesDescription } from './list-code-definition-names' import { getListCodeDefinitionNamesDescription } from "./list-code-definition-names"
import { getBrowserActionDescription } from './browser-action' import { getBrowserActionDescription } from "./browser-action"
import { getAskFollowupQuestionDescription } from './ask-followup-question' import { getAskFollowupQuestionDescription } from "./ask-followup-question"
import { getAttemptCompletionDescription } from './attempt-completion' import { getAttemptCompletionDescription } from "./attempt-completion"
import { getUseMcpToolDescription } from './use-mcp-tool' import { getUseMcpToolDescription } from "./use-mcp-tool"
import { getAccessMcpResourceDescription } from './access-mcp-resource' import { getAccessMcpResourceDescription } from "./access-mcp-resource"
import { DiffStrategy } from '../../diff/DiffStrategy' import { DiffStrategy } from "../../diff/DiffStrategy"
import { McpHub } from '../../../services/mcp/McpHub' import { McpHub } from "../../../services/mcp/McpHub"
import { Mode, codeMode, askMode } from '../modes' import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from "../../../shared/modes"
import { CODE_ALLOWED_TOOLS, READONLY_ALLOWED_TOOLS, ToolName, ReadOnlyToolName } from '../../tool-lists' import { ToolArgs } from "./types"
type AllToolNames = ToolName | ReadOnlyToolName; // Map of tool names to their description functions
const toolDescriptionMap: Record<string, (args: ToolArgs) => string | undefined> = {
// Helper function to safely check if a tool is allowed execute_command: (args) => getExecuteCommandDescription(args),
function hasAllowedTool(tools: readonly string[], tool: AllToolNames): boolean { read_file: (args) => getReadFileDescription(args),
return tools.includes(tool); write_to_file: (args) => getWriteToFileDescription(args),
search_files: (args) => getSearchFilesDescription(args),
list_files: (args) => getListFilesDescription(args),
list_code_definition_names: (args) => getListCodeDefinitionNamesDescription(args),
browser_action: (args) => getBrowserActionDescription(args),
ask_followup_question: () => getAskFollowupQuestionDescription(),
attempt_completion: () => getAttemptCompletionDescription(),
use_mcp_tool: (args) => getUseMcpToolDescription(args),
access_mcp_resource: (args) => getAccessMcpResourceDescription(args),
apply_diff: (args) =>
args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : "",
} }
export function getToolDescriptionsForMode( export function getToolDescriptionsForMode(
@@ -27,65 +37,34 @@ export function getToolDescriptionsForMode(
supportsComputerUse: boolean, supportsComputerUse: boolean,
diffStrategy?: DiffStrategy, diffStrategy?: DiffStrategy,
browserViewportSize?: string, browserViewportSize?: string,
mcpHub?: McpHub mcpHub?: McpHub,
): string { ): string {
const descriptions = [] const config = getModeConfig(mode)
const args: ToolArgs = {
const allowedTools = mode === codeMode ? CODE_ALLOWED_TOOLS : READONLY_ALLOWED_TOOLS; cwd,
supportsComputerUse,
// Core tools based on mode diffStrategy,
if (hasAllowedTool(allowedTools, 'execute_command')) { browserViewportSize,
descriptions.push(getExecuteCommandDescription(cwd)); mcpHub,
}
if (hasAllowedTool(allowedTools, 'read_file')) {
descriptions.push(getReadFileDescription(cwd));
}
if (hasAllowedTool(allowedTools, 'write_to_file')) {
descriptions.push(getWriteToFileDescription(cwd));
} }
// Optional diff strategy // Map tool descriptions in the exact order specified in the mode's tools array
if (diffStrategy && hasAllowedTool(allowedTools, 'apply_diff')) { const descriptions = config.tools.map(([toolName, toolOptions]) => {
descriptions.push(diffStrategy.getToolDescription(cwd)); const descriptionFn = toolDescriptionMap[toolName]
if (!descriptionFn || !isToolAllowedForMode(toolName as ToolName, mode)) {
return undefined
} }
// File operation tools return descriptionFn({
if (hasAllowedTool(allowedTools, 'search_files')) { ...args,
descriptions.push(getSearchFilesDescription(cwd)); toolOptions,
} })
if (hasAllowedTool(allowedTools, 'list_files')) { })
descriptions.push(getListFilesDescription(cwd));
}
if (hasAllowedTool(allowedTools, 'list_code_definition_names')) {
descriptions.push(getListCodeDefinitionNamesDescription(cwd));
}
// Browser actions return `# Tools\n\n${descriptions.filter(Boolean).join("\n\n")}`
if (supportsComputerUse && hasAllowedTool(allowedTools, 'browser_action')) {
descriptions.push(getBrowserActionDescription(cwd, browserViewportSize));
}
// Common tools at the end
if (hasAllowedTool(allowedTools, 'ask_followup_question')) {
descriptions.push(getAskFollowupQuestionDescription());
}
if (hasAllowedTool(allowedTools, 'attempt_completion')) {
descriptions.push(getAttemptCompletionDescription());
}
// MCP tools if available
if (mcpHub) {
if (hasAllowedTool(allowedTools, 'use_mcp_tool')) {
descriptions.push(getUseMcpToolDescription());
}
if (hasAllowedTool(allowedTools, 'access_mcp_resource')) {
descriptions.push(getAccessMcpResourceDescription());
}
}
return `# Tools\n\n${descriptions.filter(Boolean).join('\n\n')}`
} }
// Export individual description functions for backward compatibility
export { export {
getExecuteCommandDescription, getExecuteCommandDescription,
getReadFileDescription, getReadFileDescription,
@@ -97,5 +76,5 @@ export {
getAskFollowupQuestionDescription, getAskFollowupQuestionDescription,
getAttemptCompletionDescription, getAttemptCompletionDescription,
getUseMcpToolDescription, getUseMcpToolDescription,
getAccessMcpResourceDescription getAccessMcpResourceDescription,
} }

View File

@@ -1,8 +1,10 @@
export function getListCodeDefinitionNamesDescription(cwd: string): string { import { ToolArgs } from "./types"
export function getListCodeDefinitionNamesDescription(args: ToolArgs): string {
return `## list_code_definition_names return `## list_code_definition_names
Description: Request to list definition names (classes, functions, methods, etc.) used in source code files at the top level of the specified directory. This tool provides insights into the codebase structure and important constructs, encapsulating high-level concepts and relationships that are crucial for understanding the overall architecture. Description: Request to list definition names (classes, functions, methods, etc.) used in source code files at the top level of the specified directory. This tool provides insights into the codebase structure and important constructs, encapsulating high-level concepts and relationships that are crucial for understanding the overall architecture.
Parameters: Parameters:
- path: (required) The path of the directory (relative to the current working directory ${cwd.toPosix()}) to list top level source code definitions for. - path: (required) The path of the directory (relative to the current working directory ${args.cwd}) to list top level source code definitions for.
Usage: Usage:
<list_code_definition_names> <list_code_definition_names>
<path>Directory path here</path> <path>Directory path here</path>

View File

@@ -1,8 +1,10 @@
export function getListFilesDescription(cwd: string): string { import { ToolArgs } from "./types"
export function getListFilesDescription(args: ToolArgs): string {
return `## list_files return `## list_files
Description: Request to list files and directories within the specified directory. If recursive is true, it will list all files and directories recursively. If recursive is false or not provided, it will only list the top-level contents. Do not use this tool to confirm the existence of files you may have created, as the user will let you know if the files were created successfully or not. Description: Request to list files and directories within the specified directory. If recursive is true, it will list all files and directories recursively. If recursive is false or not provided, it will only list the top-level contents. Do not use this tool to confirm the existence of files you may have created, as the user will let you know if the files were created successfully or not.
Parameters: Parameters:
- path: (required) The path of the directory to list contents for (relative to the current working directory ${cwd.toPosix()}) - path: (required) The path of the directory to list contents for (relative to the current working directory ${args.cwd})
- recursive: (optional) Whether to list files recursively. Use true for recursive listing, false or omit for top-level only. - recursive: (optional) Whether to list files recursively. Use true for recursive listing, false or omit for top-level only.
Usage: Usage:
<list_files> <list_files>

View File

@@ -1,8 +1,10 @@
export function getReadFileDescription(cwd: string): string { import { ToolArgs } from "./types"
export function getReadFileDescription(args: ToolArgs): string {
return `## read_file return `## read_file
Description: Request to read the contents of a file at the specified path. Use this when you need to examine the contents of an existing file you do not know the contents of, for example to analyze code, review text files, or extract information from configuration files. The output includes line numbers prefixed to each line (e.g. "1 | const x = 1"), making it easier to reference specific lines when creating diffs or discussing code. Automatically extracts raw text from PDF and DOCX files. May not be suitable for other types of binary files, as it returns the raw content as a string. Description: Request to read the contents of a file at the specified path. Use this when you need to examine the contents of an existing file you do not know the contents of, for example to analyze code, review text files, or extract information from configuration files. The output includes line numbers prefixed to each line (e.g. "1 | const x = 1"), making it easier to reference specific lines when creating diffs or discussing code. Automatically extracts raw text from PDF and DOCX files. May not be suitable for other types of binary files, as it returns the raw content as a string.
Parameters: Parameters:
- path: (required) The path of the file to read (relative to the current working directory ${cwd}) - path: (required) The path of the file to read (relative to the current working directory ${args.cwd})
Usage: Usage:
<read_file> <read_file>
<path>File path here</path> <path>File path here</path>

View File

@@ -1,8 +1,10 @@
export function getSearchFilesDescription(cwd: string): string { import { ToolArgs } from "./types"
export function getSearchFilesDescription(args: ToolArgs): string {
return `## search_files return `## search_files
Description: Request to perform a regex search across files in a specified directory, providing context-rich results. This tool searches for patterns or specific content across multiple files, displaying each match with encapsulating context. Description: Request to perform a regex search across files in a specified directory, providing context-rich results. This tool searches for patterns or specific content across multiple files, displaying each match with encapsulating context.
Parameters: Parameters:
- path: (required) The path of the directory to search in (relative to the current working directory ${cwd.toPosix()}). This directory will be recursively searched. - path: (required) The path of the directory to search in (relative to the current working directory ${args.cwd}). This directory will be recursively searched.
- regex: (required) The regular expression pattern to search for. Uses Rust regex syntax. - regex: (required) The regular expression pattern to search for. Uses Rust regex syntax.
- file_pattern: (optional) Glob pattern to filter files (e.g., '*.ts' for TypeScript files). If not provided, it will search all files (*). - file_pattern: (optional) Glob pattern to filter files (e.g., '*.ts' for TypeScript files). If not provided, it will search all files (*).
Usage: Usage:

View File

@@ -0,0 +1,11 @@
import { DiffStrategy } from "../../diff/DiffStrategy"
import { McpHub } from "../../../services/mcp/McpHub"
export type ToolArgs = {
cwd: string
supportsComputerUse: boolean
diffStrategy?: DiffStrategy
browserViewportSize?: string
mcpHub?: McpHub
toolOptions?: any
}

Some files were not shown because too many files have changed in this diff Show More