mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
Prettier backfill
This commit is contained in:
@@ -2,19 +2,19 @@
|
||||
|
||||
const getReleaseLine = async (changeset) => {
|
||||
const [firstLine] = changeset.summary
|
||||
.split('\n')
|
||||
.map(l => l.trim())
|
||||
.filter(Boolean);
|
||||
return `- ${firstLine}`;
|
||||
};
|
||||
.split("\n")
|
||||
.map((l) => l.trim())
|
||||
.filter(Boolean)
|
||||
return `- ${firstLine}`
|
||||
}
|
||||
|
||||
const getDependencyReleaseLine = async () => {
|
||||
return '';
|
||||
};
|
||||
return ""
|
||||
}
|
||||
|
||||
const changelogFunctions = {
|
||||
getReleaseLine,
|
||||
getDependencyReleaseLine,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = changelogFunctions;
|
||||
module.exports = changelogFunctions
|
||||
|
||||
9
.github/pull_request_template.md
vendored
9
.github/pull_request_template.md
vendored
@@ -1,28 +1,37 @@
|
||||
<!-- **Note:** Consider creating PRs as a DRAFT. For early feedback and self-review. -->
|
||||
|
||||
## Description
|
||||
|
||||
## Type of change
|
||||
|
||||
<!-- Please ignore options that are not relevant -->
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] This change requires a documentation update
|
||||
|
||||
## How Has This Been Tested?
|
||||
|
||||
<!-- Please describe the tests that you ran to verify your changes -->
|
||||
|
||||
## Checklist:
|
||||
|
||||
<!-- Go over all the following points, and put an `x` in all the boxes that apply -->
|
||||
|
||||
- [ ] My code follows the patterns of this project
|
||||
- [ ] I have performed a self-review of my own code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] I have made corresponding changes to the documentation
|
||||
|
||||
## Additional context
|
||||
|
||||
<!-- Add any other context or screenshots about the pull request here -->
|
||||
|
||||
## Related Issues
|
||||
|
||||
<!-- List any related issues here. Use the GitHub issue linking syntax: #issue-number -->
|
||||
|
||||
## Reviewers
|
||||
|
||||
<!-- @mention specific team members or individuals who should review this PR -->
|
||||
@@ -11,6 +11,7 @@ Hot off the heels of **v3.0** introducing Code, Architect, and Ask chat modes, o
|
||||
You can now tailor the **role definition** and **custom instructions** for every chat mode to perfectly fit your workflow. Want to adjust Architect mode to focus more on system scalability? Or tweak Ask mode for deeper research queries? Done. Plus, you can define these via **mode-specific `.clinerules-[mode]` files**. You’ll find all of this in the new **Prompts** tab in the top menu.
|
||||
|
||||
The second big feature in this release is a complete revamp of **prompt enhancements**. This feature helps you craft messages to get even better results from Cline. Here’s what’s new:
|
||||
|
||||
- Works with **any provider** and API configuration, not just OpenRouter.
|
||||
- Fully customizable prompts to match your unique needs.
|
||||
- Same simple workflow: just hit the ✨ **Enhance Prompt** button in the chat input to try it out.
|
||||
@@ -33,6 +34,7 @@ You can now choose between different prompts for Roo Cline to better suit your w
|
||||
It’s super simple! There’s a dropdown in the bottom left of the chat input to switch modes. Right next to it, you’ll find a way to switch between the API configuration profiles associated with the current mode (configured on the settings screen).
|
||||
|
||||
**Why Add This?**
|
||||
|
||||
- It keeps Cline from being overly eager to jump into solving problems when you just want to think or ask questions.
|
||||
- Each mode remembers the API configuration you last used with it. For example, you can use more thoughtful models like OpenAI o1 for Architect and Ask, while sticking with Sonnet or DeepSeek for coding tasks.
|
||||
- It builds on research suggesting better results when separating "thinking" from "coding," explained well in this very thoughtful [article](https://aider.chat/2024/09/26/architect.html) from aider.
|
||||
@@ -50,11 +52,13 @@ Here's an example of Roo-Cline autonomously creating a snake game with "Always a
|
||||
https://github.com/user-attachments/assets/c2bb31dc-e9b2-4d73-885d-17f1471a4987
|
||||
|
||||
## Contributing
|
||||
|
||||
To contribute to the project, start by exploring [open issues](https://github.com/RooVetGit/Roo-Cline/issues) or checking our [feature request board](https://github.com/RooVetGit/Roo-Cline/discussions/categories/feature-requests). We'd also love to have you join the [Roo Cline Reddit](https://www.reddit.com/r/roocline/) to share ideas and connect with other contributors.
|
||||
|
||||
### Local Setup
|
||||
|
||||
1. Install dependencies:
|
||||
|
||||
```bash
|
||||
npm run install:all
|
||||
```
|
||||
@@ -89,6 +93,7 @@ We use [changesets](https://github.com/changesets/changesets) for versioning and
|
||||
4. Merge it
|
||||
|
||||
Once your merge is successful:
|
||||
|
||||
- The release workflow will automatically create a new "Changeset version bump" PR
|
||||
- This PR will:
|
||||
- Update the version based on your changeset
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
## For All Settings
|
||||
|
||||
1. Add the setting to ExtensionMessage.ts:
|
||||
|
||||
- Add the setting to the ExtensionState interface
|
||||
- Make it required if it has a default value, optional if it can be undefined
|
||||
- Example: `preferredLanguage: string`
|
||||
@@ -14,10 +14,12 @@
|
||||
## For Checkbox Settings
|
||||
|
||||
1. Add the message type to WebviewMessage.ts:
|
||||
|
||||
- Add the setting name to the WebviewMessage type's type union
|
||||
- Example: `| "multisearchDiffEnabled"`
|
||||
|
||||
2. Add the setting to ExtensionStateContext.tsx:
|
||||
|
||||
- Add the setting to the ExtensionStateContextType interface
|
||||
- Add the setter function to the interface
|
||||
- Add the setting to the initial state in useState
|
||||
@@ -25,12 +27,13 @@
|
||||
- Example:
|
||||
```typescript
|
||||
interface ExtensionStateContextType {
|
||||
multisearchDiffEnabled: boolean;
|
||||
setMultisearchDiffEnabled: (value: boolean) => void;
|
||||
multisearchDiffEnabled: boolean
|
||||
setMultisearchDiffEnabled: (value: boolean) => void
|
||||
}
|
||||
```
|
||||
|
||||
3. Add the setting to ClineProvider.ts:
|
||||
|
||||
- Add the setting name to the GlobalStateKey type union
|
||||
- Add the setting to the Promise.all array in getState
|
||||
- Add the setting to the return value in getState with a default value
|
||||
@@ -46,6 +49,7 @@
|
||||
```
|
||||
|
||||
4. Add the checkbox UI to SettingsView.tsx:
|
||||
|
||||
- Import the setting and its setter from ExtensionStateContext
|
||||
- Add the VSCodeCheckbox component with the setting's state and onChange handler
|
||||
- Add appropriate labels and description text
|
||||
@@ -69,10 +73,12 @@
|
||||
## For Select/Dropdown Settings
|
||||
|
||||
1. Add the message type to WebviewMessage.ts:
|
||||
|
||||
- Add the setting name to the WebviewMessage type's type union
|
||||
- Example: `| "preferredLanguage"`
|
||||
|
||||
2. Add the setting to ExtensionStateContext.tsx:
|
||||
|
||||
- Add the setting to the ExtensionStateContextType interface
|
||||
- Add the setter function to the interface
|
||||
- Add the setting to the initial state in useState with a default value
|
||||
@@ -80,12 +86,13 @@
|
||||
- Example:
|
||||
```typescript
|
||||
interface ExtensionStateContextType {
|
||||
preferredLanguage: string;
|
||||
setPreferredLanguage: (value: string) => void;
|
||||
preferredLanguage: string
|
||||
setPreferredLanguage: (value: string) => void
|
||||
}
|
||||
```
|
||||
|
||||
3. Add the setting to ClineProvider.ts:
|
||||
|
||||
- Add the setting name to the GlobalStateKey type union
|
||||
- Add the setting to the Promise.all array in getState
|
||||
- Add the setting to the return value in getState with a default value
|
||||
@@ -101,6 +108,7 @@
|
||||
```
|
||||
|
||||
4. Add the select UI to SettingsView.tsx:
|
||||
|
||||
- Import the setting and its setter from ExtensionStateContext
|
||||
- Add the select element with appropriate styling to match VSCode's theme
|
||||
- Add options for the dropdown
|
||||
@@ -132,6 +140,7 @@
|
||||
```
|
||||
|
||||
These steps ensure that:
|
||||
|
||||
- The setting's state is properly typed throughout the application
|
||||
- The setting persists between sessions
|
||||
- The setting's value is properly synchronized between the webview and extension
|
||||
|
||||
@@ -1,41 +1,40 @@
|
||||
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
|
||||
preset: "ts-jest",
|
||||
testEnvironment: "node",
|
||||
moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"],
|
||||
transform: {
|
||||
'^.+\\.tsx?$': ['ts-jest', {
|
||||
"^.+\\.tsx?$": [
|
||||
"ts-jest",
|
||||
{
|
||||
tsconfig: {
|
||||
"module": "CommonJS",
|
||||
"moduleResolution": "node",
|
||||
"esModuleInterop": true,
|
||||
"allowJs": true
|
||||
module: "CommonJS",
|
||||
moduleResolution: "node",
|
||||
esModuleInterop: true,
|
||||
allowJs: true,
|
||||
},
|
||||
diagnostics: false,
|
||||
isolatedModules: true
|
||||
}]
|
||||
isolatedModules: true,
|
||||
},
|
||||
testMatch: ['**/__tests__/**/*.test.ts'],
|
||||
],
|
||||
},
|
||||
testMatch: ["**/__tests__/**/*.test.ts"],
|
||||
moduleNameMapper: {
|
||||
'^vscode$': '<rootDir>/src/__mocks__/vscode.js',
|
||||
'@modelcontextprotocol/sdk$': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js',
|
||||
'@modelcontextprotocol/sdk/(.*)': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1',
|
||||
'^delay$': '<rootDir>/src/__mocks__/delay.js',
|
||||
'^p-wait-for$': '<rootDir>/src/__mocks__/p-wait-for.js',
|
||||
'^globby$': '<rootDir>/src/__mocks__/globby.js',
|
||||
'^serialize-error$': '<rootDir>/src/__mocks__/serialize-error.js',
|
||||
'^strip-ansi$': '<rootDir>/src/__mocks__/strip-ansi.js',
|
||||
'^default-shell$': '<rootDir>/src/__mocks__/default-shell.js',
|
||||
'^os-name$': '<rootDir>/src/__mocks__/os-name.js'
|
||||
"^vscode$": "<rootDir>/src/__mocks__/vscode.js",
|
||||
"@modelcontextprotocol/sdk$": "<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js",
|
||||
"@modelcontextprotocol/sdk/(.*)": "<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1",
|
||||
"^delay$": "<rootDir>/src/__mocks__/delay.js",
|
||||
"^p-wait-for$": "<rootDir>/src/__mocks__/p-wait-for.js",
|
||||
"^globby$": "<rootDir>/src/__mocks__/globby.js",
|
||||
"^serialize-error$": "<rootDir>/src/__mocks__/serialize-error.js",
|
||||
"^strip-ansi$": "<rootDir>/src/__mocks__/strip-ansi.js",
|
||||
"^default-shell$": "<rootDir>/src/__mocks__/default-shell.js",
|
||||
"^os-name$": "<rootDir>/src/__mocks__/os-name.js",
|
||||
},
|
||||
transformIgnorePatterns: [
|
||||
'node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)'
|
||||
"node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)",
|
||||
],
|
||||
modulePathIgnorePatterns: [
|
||||
'.vscode-test'
|
||||
],
|
||||
reporters: [
|
||||
["jest-simple-dot-reporter", {}]
|
||||
],
|
||||
setupFiles: []
|
||||
modulePathIgnorePatterns: [".vscode-test"],
|
||||
reporters: [["jest-simple-dot-reporter", {}]],
|
||||
setupFiles: [],
|
||||
}
|
||||
|
||||
@@ -13,5 +13,5 @@ class Client {
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Client
|
||||
Client,
|
||||
}
|
||||
@@ -3,14 +3,14 @@ class StdioClientTransport {
|
||||
this.start = jest.fn().mockResolvedValue(undefined)
|
||||
this.close = jest.fn().mockResolvedValue(undefined)
|
||||
this.stderr = {
|
||||
on: jest.fn()
|
||||
on: jest.fn(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class StdioServerParameters {
|
||||
constructor() {
|
||||
this.command = ''
|
||||
this.command = ""
|
||||
this.args = []
|
||||
this.env = {}
|
||||
}
|
||||
@@ -18,5 +18,5 @@ class StdioServerParameters {
|
||||
|
||||
module.exports = {
|
||||
StdioClientTransport,
|
||||
StdioServerParameters
|
||||
StdioServerParameters,
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
const { Client } = require('./client/index.js')
|
||||
const { StdioClientTransport, StdioServerParameters } = require('./client/stdio.js')
|
||||
const { Client } = require("./client/index.js")
|
||||
const { StdioClientTransport, StdioServerParameters } = require("./client/stdio.js")
|
||||
const {
|
||||
CallToolResultSchema,
|
||||
ListToolsResultSchema,
|
||||
@@ -7,8 +7,8 @@ const {
|
||||
ListResourceTemplatesResultSchema,
|
||||
ReadResourceResultSchema,
|
||||
ErrorCode,
|
||||
McpError
|
||||
} = require('./types.js')
|
||||
McpError,
|
||||
} = require("./types.js")
|
||||
|
||||
module.exports = {
|
||||
Client,
|
||||
@@ -20,5 +20,5 @@ module.exports = {
|
||||
ListResourceTemplatesResultSchema,
|
||||
ReadResourceResultSchema,
|
||||
ErrorCode,
|
||||
McpError
|
||||
McpError,
|
||||
}
|
||||
@@ -1,36 +1,36 @@
|
||||
const CallToolResultSchema = {
|
||||
parse: jest.fn().mockReturnValue({})
|
||||
parse: jest.fn().mockReturnValue({}),
|
||||
}
|
||||
|
||||
const ListToolsResultSchema = {
|
||||
parse: jest.fn().mockReturnValue({
|
||||
tools: []
|
||||
})
|
||||
tools: [],
|
||||
}),
|
||||
}
|
||||
|
||||
const ListResourcesResultSchema = {
|
||||
parse: jest.fn().mockReturnValue({
|
||||
resources: []
|
||||
})
|
||||
resources: [],
|
||||
}),
|
||||
}
|
||||
|
||||
const ListResourceTemplatesResultSchema = {
|
||||
parse: jest.fn().mockReturnValue({
|
||||
resourceTemplates: []
|
||||
})
|
||||
resourceTemplates: [],
|
||||
}),
|
||||
}
|
||||
|
||||
const ReadResourceResultSchema = {
|
||||
parse: jest.fn().mockReturnValue({
|
||||
contents: []
|
||||
})
|
||||
contents: [],
|
||||
}),
|
||||
}
|
||||
|
||||
const ErrorCode = {
|
||||
InvalidRequest: 'InvalidRequest',
|
||||
MethodNotFound: 'MethodNotFound',
|
||||
InvalidParams: 'InvalidParams',
|
||||
InternalError: 'InternalError'
|
||||
InvalidRequest: "InvalidRequest",
|
||||
MethodNotFound: "MethodNotFound",
|
||||
InvalidParams: "InvalidParams",
|
||||
InternalError: "InternalError",
|
||||
}
|
||||
|
||||
class McpError extends Error {
|
||||
@@ -47,5 +47,5 @@ module.exports = {
|
||||
ListResourceTemplatesResultSchema,
|
||||
ReadResourceResultSchema,
|
||||
ErrorCode,
|
||||
McpError
|
||||
McpError,
|
||||
}
|
||||
@@ -12,6 +12,6 @@ export class McpHub {
|
||||
}
|
||||
|
||||
async callTool(serverName: string, toolName: string, toolArguments?: Record<string, unknown>): Promise<any> {
|
||||
return Promise.resolve({ result: 'success' })
|
||||
return Promise.resolve({ result: "success" })
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
// Mock default shell based on platform
|
||||
const os = require('os');
|
||||
const os = require("os")
|
||||
|
||||
let defaultShell;
|
||||
if (os.platform() === 'win32') {
|
||||
defaultShell = 'cmd.exe';
|
||||
let defaultShell
|
||||
if (os.platform() === "win32") {
|
||||
defaultShell = "cmd.exe"
|
||||
} else {
|
||||
defaultShell = '/bin/bash';
|
||||
defaultShell = "/bin/bash"
|
||||
}
|
||||
|
||||
module.exports = defaultShell;
|
||||
module.exports.default = defaultShell;
|
||||
module.exports = defaultShell
|
||||
module.exports.default = defaultShell
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
function delay(ms) {
|
||||
return new Promise(resolve => setTimeout(resolve, ms));
|
||||
return new Promise((resolve) => setTimeout(resolve, ms))
|
||||
}
|
||||
|
||||
module.exports = delay;
|
||||
module.exports.default = delay;
|
||||
module.exports = delay
|
||||
module.exports.default = delay
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
function globby(patterns, options) {
|
||||
return Promise.resolve([]);
|
||||
return Promise.resolve([])
|
||||
}
|
||||
|
||||
globby.sync = function (patterns, options) {
|
||||
return [];
|
||||
};
|
||||
return []
|
||||
}
|
||||
|
||||
module.exports = globby;
|
||||
module.exports.default = globby;
|
||||
module.exports = globby
|
||||
module.exports.default = globby
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
function osName() {
|
||||
return 'macOS';
|
||||
return "macOS"
|
||||
}
|
||||
|
||||
module.exports = osName;
|
||||
module.exports.default = osName;
|
||||
module.exports = osName
|
||||
module.exports.default = osName
|
||||
|
||||
@@ -2,19 +2,19 @@ function pWaitFor(condition, options = {}) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const interval = setInterval(() => {
|
||||
if (condition()) {
|
||||
clearInterval(interval);
|
||||
resolve();
|
||||
clearInterval(interval)
|
||||
resolve()
|
||||
}
|
||||
}, options.interval || 20);
|
||||
}, options.interval || 20)
|
||||
|
||||
if (options.timeout) {
|
||||
setTimeout(() => {
|
||||
clearInterval(interval);
|
||||
reject(new Error('Timed out'));
|
||||
}, options.timeout);
|
||||
clearInterval(interval)
|
||||
reject(new Error("Timed out"))
|
||||
}, options.timeout)
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
module.exports = pWaitFor;
|
||||
module.exports.default = pWaitFor;
|
||||
module.exports = pWaitFor
|
||||
module.exports.default = pWaitFor
|
||||
|
||||
@@ -3,23 +3,23 @@ function serializeError(error) {
|
||||
return {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
};
|
||||
stack: error.stack,
|
||||
}
|
||||
return error;
|
||||
}
|
||||
return error
|
||||
}
|
||||
|
||||
function deserializeError(errorData) {
|
||||
if (errorData && typeof errorData === 'object') {
|
||||
const error = new Error(errorData.message);
|
||||
error.name = errorData.name;
|
||||
error.stack = errorData.stack;
|
||||
return error;
|
||||
if (errorData && typeof errorData === "object") {
|
||||
const error = new Error(errorData.message)
|
||||
error.name = errorData.name
|
||||
error.stack = errorData.stack
|
||||
return error
|
||||
}
|
||||
return errorData;
|
||||
return errorData
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
serializeError,
|
||||
deserializeError
|
||||
};
|
||||
deserializeError,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
function stripAnsi(string) {
|
||||
// Simple mock that just returns the input string
|
||||
return string;
|
||||
return string
|
||||
}
|
||||
|
||||
module.exports = stripAnsi;
|
||||
module.exports.default = stripAnsi;
|
||||
module.exports = stripAnsi
|
||||
module.exports.default = stripAnsi
|
||||
|
||||
@@ -3,11 +3,11 @@ const vscode = {
|
||||
showInformationMessage: jest.fn(),
|
||||
showErrorMessage: jest.fn(),
|
||||
createTextEditorDecorationType: jest.fn().mockReturnValue({
|
||||
dispose: jest.fn()
|
||||
})
|
||||
dispose: jest.fn(),
|
||||
}),
|
||||
},
|
||||
workspace: {
|
||||
onDidSaveTextDocument: jest.fn()
|
||||
onDidSaveTextDocument: jest.fn(),
|
||||
},
|
||||
Disposable: class {
|
||||
dispose() {}
|
||||
@@ -15,43 +15,43 @@ const vscode = {
|
||||
Uri: {
|
||||
file: (path) => ({
|
||||
fsPath: path,
|
||||
scheme: 'file',
|
||||
authority: '',
|
||||
scheme: "file",
|
||||
authority: "",
|
||||
path: path,
|
||||
query: '',
|
||||
fragment: '',
|
||||
query: "",
|
||||
fragment: "",
|
||||
with: jest.fn(),
|
||||
toJSON: jest.fn()
|
||||
})
|
||||
toJSON: jest.fn(),
|
||||
}),
|
||||
},
|
||||
EventEmitter: class {
|
||||
constructor() {
|
||||
this.event = jest.fn();
|
||||
this.fire = jest.fn();
|
||||
this.event = jest.fn()
|
||||
this.fire = jest.fn()
|
||||
}
|
||||
},
|
||||
ConfigurationTarget: {
|
||||
Global: 1,
|
||||
Workspace: 2,
|
||||
WorkspaceFolder: 3
|
||||
WorkspaceFolder: 3,
|
||||
},
|
||||
Position: class {
|
||||
constructor(line, character) {
|
||||
this.line = line;
|
||||
this.character = character;
|
||||
this.line = line
|
||||
this.character = character
|
||||
}
|
||||
},
|
||||
Range: class {
|
||||
constructor(startLine, startCharacter, endLine, endCharacter) {
|
||||
this.start = new vscode.Position(startLine, startCharacter);
|
||||
this.end = new vscode.Position(endLine, endCharacter);
|
||||
this.start = new vscode.Position(startLine, startCharacter)
|
||||
this.end = new vscode.Position(endLine, endCharacter)
|
||||
}
|
||||
},
|
||||
ThemeColor: class {
|
||||
constructor(id) {
|
||||
this.id = id;
|
||||
this.id = id
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = vscode;
|
||||
module.exports = vscode
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { AnthropicHandler } from '../anthropic';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import { ApiStream } from '../../transform/stream';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { AnthropicHandler } from "../anthropic"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import { ApiStream } from "../../transform/stream"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock Anthropic client
|
||||
const mockBetaCreate = jest.fn();
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('@anthropic-ai/sdk', () => {
|
||||
const mockBetaCreate = jest.fn()
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("@anthropic-ai/sdk", () => {
|
||||
return {
|
||||
Anthropic: jest.fn().mockImplementation(() => ({
|
||||
beta: {
|
||||
@@ -15,225 +15,224 @@ jest.mock('@anthropic-ai/sdk', () => {
|
||||
create: mockBetaCreate.mockImplementation(async () => ({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
cache_creation_input_tokens: 20,
|
||||
cache_read_input_tokens: 10
|
||||
cache_read_input_tokens: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
},
|
||||
}
|
||||
};
|
||||
yield {
|
||||
type: 'content_block_delta',
|
||||
type: "content_block_delta",
|
||||
delta: {
|
||||
type: 'text_delta',
|
||||
text: ' world'
|
||||
}
|
||||
};
|
||||
}
|
||||
}))
|
||||
}
|
||||
type: "text_delta",
|
||||
text: " world",
|
||||
},
|
||||
}
|
||||
},
|
||||
})),
|
||||
},
|
||||
},
|
||||
},
|
||||
messages: {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
content: [
|
||||
{ type: 'text', text: 'Test response' }
|
||||
],
|
||||
role: 'assistant',
|
||||
id: "test-completion",
|
||||
content: [{ type: "text", text: "Test response" }],
|
||||
role: "assistant",
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
output_tokens: 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
output_tokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
type: "content_block_start",
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
}
|
||||
type: "text",
|
||||
text: "Test response",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
})),
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('AnthropicHandler', () => {
|
||||
let handler: AnthropicHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
describe("AnthropicHandler", () => {
|
||||
let handler: AnthropicHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiKey: 'test-api-key',
|
||||
apiModelId: 'claude-3-5-sonnet-20241022'
|
||||
};
|
||||
handler = new AnthropicHandler(mockOptions);
|
||||
mockBetaCreate.mockClear();
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
apiKey: "test-api-key",
|
||||
apiModelId: "claude-3-5-sonnet-20241022",
|
||||
}
|
||||
handler = new AnthropicHandler(mockOptions)
|
||||
mockBetaCreate.mockClear()
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(AnthropicHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(AnthropicHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||
})
|
||||
|
||||
it('should initialize with undefined API key', () => {
|
||||
it("should initialize with undefined API key", () => {
|
||||
// The SDK will handle API key validation, so we just verify it initializes
|
||||
const handlerWithoutKey = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
apiKey: undefined
|
||||
});
|
||||
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler);
|
||||
});
|
||||
apiKey: undefined,
|
||||
})
|
||||
expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler)
|
||||
})
|
||||
|
||||
it('should use custom base URL if provided', () => {
|
||||
const customBaseUrl = 'https://custom.anthropic.com';
|
||||
it("should use custom base URL if provided", () => {
|
||||
const customBaseUrl = "https://custom.anthropic.com"
|
||||
const handlerWithCustomUrl = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
anthropicBaseUrl: customBaseUrl
|
||||
});
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler);
|
||||
});
|
||||
});
|
||||
anthropicBaseUrl: customBaseUrl,
|
||||
})
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello!'
|
||||
}]
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello!",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle prompt caching for supported models', async () => {
|
||||
it("should handle prompt caching for supported models", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text' as const, text: 'First message' }]
|
||||
role: "user",
|
||||
content: [{ type: "text" as const, text: "First message" }],
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text' as const, text: 'Response' }]
|
||||
role: "assistant",
|
||||
content: [{ type: "text" as const, text: "Response" }],
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text' as const, text: 'Second message' }]
|
||||
}
|
||||
]);
|
||||
role: "user",
|
||||
content: [{ type: "text" as const, text: "Second message" }],
|
||||
},
|
||||
])
|
||||
|
||||
const chunks: any[] = [];
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
// Verify usage information
|
||||
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
|
||||
expect(usageChunk).toBeDefined();
|
||||
expect(usageChunk?.inputTokens).toBe(100);
|
||||
expect(usageChunk?.outputTokens).toBe(50);
|
||||
expect(usageChunk?.cacheWriteTokens).toBe(20);
|
||||
expect(usageChunk?.cacheReadTokens).toBe(10);
|
||||
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
|
||||
expect(usageChunk).toBeDefined()
|
||||
expect(usageChunk?.inputTokens).toBe(100)
|
||||
expect(usageChunk?.outputTokens).toBe(50)
|
||||
expect(usageChunk?.cacheWriteTokens).toBe(20)
|
||||
expect(usageChunk?.cacheReadTokens).toBe(10)
|
||||
|
||||
// Verify text content
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(2);
|
||||
expect(textChunks[0].text).toBe('Hello');
|
||||
expect(textChunks[1].text).toBe(' world');
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(2)
|
||||
expect(textChunks[0].text).toBe("Hello")
|
||||
expect(textChunks[1].text).toBe(" world")
|
||||
|
||||
// Verify beta API was used
|
||||
expect(mockBetaCreate).toHaveBeenCalled();
|
||||
expect(mockCreate).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
expect(mockBetaCreate).toHaveBeenCalled()
|
||||
expect(mockCreate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.apiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Anthropic completion error: API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle non-text content', async () => {
|
||||
it("should handle non-text content", async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: 'image' }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
content: [{ type: "image" }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockImplementationOnce(async () => ({
|
||||
content: [{ type: 'text', text: '' }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
content: [{ type: "text", text: "" }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return default model if no model ID is provided', () => {
|
||||
describe("getModel", () => {
|
||||
it("should return default model if no model ID is provided", () => {
|
||||
const handlerWithoutModel = new AnthropicHandler({
|
||||
...mockOptions,
|
||||
apiModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBeDefined();
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
apiModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBeDefined()
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
|
||||
it('should return specified model if valid model ID is provided', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.apiModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.maxTokens).toBe(8192);
|
||||
expect(model.info.contextWindow).toBe(200_000);
|
||||
expect(model.info.supportsImages).toBe(true);
|
||||
expect(model.info.supportsPromptCache).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should return specified model if valid model ID is provided", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.apiModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.maxTokens).toBe(8192)
|
||||
expect(model.info.contextWindow).toBe(200_000)
|
||||
expect(model.info.supportsImages).toBe(true)
|
||||
expect(model.info.supportsPromptCache).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,62 +1,64 @@
|
||||
import { AwsBedrockHandler } from '../bedrock';
|
||||
import { MessageContent } from '../../../shared/api';
|
||||
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { AwsBedrockHandler } from "../bedrock"
|
||||
import { MessageContent } from "../../../shared/api"
|
||||
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
describe('AwsBedrockHandler', () => {
|
||||
let handler: AwsBedrockHandler;
|
||||
describe("AwsBedrockHandler", () => {
|
||||
let handler: AwsBedrockHandler
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1'
|
||||
});
|
||||
});
|
||||
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
awsAccessKey: "test-access-key",
|
||||
awsSecretKey: "test-secret-key",
|
||||
awsRegion: "us-east-1",
|
||||
})
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(handler['options'].awsAccessKey).toBe('test-access-key');
|
||||
expect(handler['options'].awsSecretKey).toBe('test-secret-key');
|
||||
expect(handler['options'].awsRegion).toBe('us-east-1');
|
||||
expect(handler['options'].apiModelId).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided config", () => {
|
||||
expect(handler["options"].awsAccessKey).toBe("test-access-key")
|
||||
expect(handler["options"].awsSecretKey).toBe("test-secret-key")
|
||||
expect(handler["options"].awsRegion).toBe("us-east-1")
|
||||
expect(handler["options"].apiModelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
})
|
||||
|
||||
it('should initialize with missing AWS credentials', () => {
|
||||
it("should initialize with missing AWS credentials", () => {
|
||||
const handlerWithoutCreds = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsRegion: 'us-east-1'
|
||||
});
|
||||
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler);
|
||||
});
|
||||
});
|
||||
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
awsRegion: "us-east-1",
|
||||
})
|
||||
expect(handlerWithoutCreds).toBeInstanceOf(AwsBedrockHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
describe("createMessage", () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
|
||||
it('should handle text messages correctly', async () => {
|
||||
it("should handle text messages correctly", async () => {
|
||||
const mockResponse = {
|
||||
messages: [{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Hello! How can I help you?' }]
|
||||
}],
|
||||
messages: [
|
||||
{
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello! How can I help you?" }],
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
output_tokens: 5,
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
// Mock AWS SDK invoke
|
||||
const mockStream = {
|
||||
@@ -65,182 +67,193 @@ describe('AwsBedrockHandler', () => {
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 5
|
||||
outputTokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const mockInvoke = jest.fn().mockResolvedValue({
|
||||
stream: mockStream
|
||||
});
|
||||
stream: mockStream,
|
||||
})
|
||||
|
||||
handler['client'] = {
|
||||
send: mockInvoke
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
handler["client"] = {
|
||||
send: mockInvoke,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 5
|
||||
});
|
||||
|
||||
expect(mockInvoke).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0'
|
||||
outputTokens: 5,
|
||||
})
|
||||
}));
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
expect(mockInvoke).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it("should handle API errors", async () => {
|
||||
// Mock AWS SDK invoke with error
|
||||
const mockInvoke = jest.fn().mockRejectedValue(new Error('AWS Bedrock error'));
|
||||
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
|
||||
|
||||
handler['client'] = {
|
||||
send: mockInvoke
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
handler["client"] = {
|
||||
send: mockInvoke,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('AWS Bedrock error');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("AWS Bedrock error")
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({
|
||||
content: 'Test response'
|
||||
}))
|
||||
};
|
||||
output: new TextEncoder().encode(
|
||||
JSON.stringify({
|
||||
content: "Test response",
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockSend).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
modelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'user',
|
||||
content: [{ text: 'Test prompt' }]
|
||||
})
|
||||
role: "user",
|
||||
content: [{ text: "Test prompt" }],
|
||||
}),
|
||||
]),
|
||||
inferenceConfig: expect.objectContaining({
|
||||
maxTokens: 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
topP: 0.1,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("AWS Bedrock error")
|
||||
const mockSend = jest.fn().mockRejectedValue(mockError)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"Bedrock completion error: AWS Bedrock error",
|
||||
)
|
||||
})
|
||||
}));
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('AWS Bedrock error');
|
||||
const mockSend = jest.fn().mockRejectedValue(mockError);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Bedrock completion error: AWS Bedrock error');
|
||||
});
|
||||
|
||||
it('should handle invalid response format', async () => {
|
||||
it("should handle invalid response format", async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode('invalid json')
|
||||
};
|
||||
output: new TextEncoder().encode("invalid json"),
|
||||
}
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({}))
|
||||
};
|
||||
output: new TextEncoder().encode(JSON.stringify({})),
|
||||
}
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
it('should handle cross-region inference', async () => {
|
||||
it("should handle cross-region inference", async () => {
|
||||
handler = new AwsBedrockHandler({
|
||||
apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1',
|
||||
awsUseCrossRegionInference: true
|
||||
});
|
||||
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
awsAccessKey: "test-access-key",
|
||||
awsSecretKey: "test-secret-key",
|
||||
awsRegion: "us-east-1",
|
||||
awsUseCrossRegionInference: true,
|
||||
})
|
||||
|
||||
const mockResponse = {
|
||||
output: new TextEncoder().encode(JSON.stringify({
|
||||
content: 'Test response'
|
||||
}))
|
||||
};
|
||||
output: new TextEncoder().encode(
|
||||
JSON.stringify({
|
||||
content: "Test response",
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse);
|
||||
handler['client'] = {
|
||||
send: mockSend
|
||||
} as unknown as BedrockRuntimeClient;
|
||||
const mockSend = jest.fn().mockResolvedValue(mockResponse)
|
||||
handler["client"] = {
|
||||
send: mockSend,
|
||||
} as unknown as BedrockRuntimeClient
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockSend).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
|
||||
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info in test environment', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('anthropic.claude-3-5-sonnet-20241022-v2:0');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(5000); // Test environment value
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000); // Test environment value
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return correct model info in test environment", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(5000) // Test environment value
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000) // Test environment value
|
||||
})
|
||||
|
||||
it('should return test model info for invalid model in test environment', () => {
|
||||
it("should return test model info for invalid model in test environment", () => {
|
||||
const invalidHandler = new AwsBedrockHandler({
|
||||
apiModelId: 'invalid-model',
|
||||
awsAccessKey: 'test-access-key',
|
||||
awsSecretKey: 'test-secret-key',
|
||||
awsRegion: 'us-east-1'
|
||||
});
|
||||
const modelInfo = invalidHandler.getModel();
|
||||
expect(modelInfo.id).toBe('invalid-model'); // In test env, returns whatever is passed
|
||||
expect(modelInfo.info.maxTokens).toBe(5000);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
});
|
||||
});
|
||||
apiModelId: "invalid-model",
|
||||
awsAccessKey: "test-access-key",
|
||||
awsSecretKey: "test-secret-key",
|
||||
awsRegion: "us-east-1",
|
||||
})
|
||||
const modelInfo = invalidHandler.getModel()
|
||||
expect(modelInfo.id).toBe("invalid-model") // In test env, returns whatever is passed
|
||||
expect(modelInfo.info.maxTokens).toBe(5000)
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { DeepSeekHandler } from '../deepseek';
|
||||
import { ApiHandlerOptions, deepSeekDefaultModelId } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { DeepSeekHandler } from "../deepseek"
|
||||
import { ApiHandlerOptions, deepSeekDefaultModelId } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
@@ -14,190 +14,204 @@ jest.mock('openai', () => {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response', refusal: null },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response", refusal: null },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Return async iterator for streaming
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('DeepSeekHandler', () => {
|
||||
let handler: DeepSeekHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
describe("DeepSeekHandler", () => {
|
||||
let handler: DeepSeekHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
deepSeekApiKey: 'test-api-key',
|
||||
deepSeekModelId: 'deepseek-chat',
|
||||
deepSeekBaseUrl: 'https://api.deepseek.com/v1'
|
||||
};
|
||||
handler = new DeepSeekHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
deepSeekApiKey: "test-api-key",
|
||||
deepSeekModelId: "deepseek-chat",
|
||||
deepSeekBaseUrl: "https://api.deepseek.com/v1",
|
||||
}
|
||||
handler = new DeepSeekHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(DeepSeekHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(DeepSeekHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.deepSeekModelId)
|
||||
})
|
||||
|
||||
it('should throw error if API key is missing', () => {
|
||||
it("should throw error if API key is missing", () => {
|
||||
expect(() => {
|
||||
new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekApiKey: undefined
|
||||
});
|
||||
}).toThrow('DeepSeek API key is required');
|
||||
});
|
||||
deepSeekApiKey: undefined,
|
||||
})
|
||||
}).toThrow("DeepSeek API key is required")
|
||||
})
|
||||
|
||||
it('should use default model ID if not provided', () => {
|
||||
it("should use default model ID if not provided", () => {
|
||||
const handlerWithoutModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: undefined
|
||||
});
|
||||
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId);
|
||||
});
|
||||
deepSeekModelId: undefined,
|
||||
})
|
||||
expect(handlerWithoutModel.getModel().id).toBe(deepSeekDefaultModelId)
|
||||
})
|
||||
|
||||
it('should use default base URL if not provided', () => {
|
||||
it("should use default base URL if not provided", () => {
|
||||
const handlerWithoutBaseUrl = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekBaseUrl: undefined
|
||||
});
|
||||
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler);
|
||||
deepSeekBaseUrl: undefined,
|
||||
})
|
||||
expect(handlerWithoutBaseUrl).toBeInstanceOf(DeepSeekHandler)
|
||||
// The base URL is passed to OpenAI client internally
|
||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
||||
baseURL: 'https://api.deepseek.com/v1'
|
||||
}));
|
||||
});
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
baseURL: "https://api.deepseek.com/v1",
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should use custom base URL if provided', () => {
|
||||
const customBaseUrl = 'https://custom.deepseek.com/v1';
|
||||
it("should use custom base URL if provided", () => {
|
||||
const customBaseUrl = "https://custom.deepseek.com/v1"
|
||||
const handlerWithCustomUrl = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekBaseUrl: customBaseUrl
|
||||
});
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler);
|
||||
deepSeekBaseUrl: customBaseUrl,
|
||||
})
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(DeepSeekHandler)
|
||||
// The custom base URL is passed to OpenAI client
|
||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
||||
baseURL: customBaseUrl
|
||||
}));
|
||||
});
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
baseURL: customBaseUrl,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should set includeMaxTokens to true', () => {
|
||||
it("should set includeMaxTokens to true", () => {
|
||||
// Create a new handler and verify OpenAI client was called with includeMaxTokens
|
||||
new DeepSeekHandler(mockOptions);
|
||||
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({
|
||||
apiKey: mockOptions.deepSeekApiKey
|
||||
}));
|
||||
});
|
||||
});
|
||||
new DeepSeekHandler(mockOptions)
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
apiKey: mockOptions.deepSeekApiKey,
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info for valid model ID', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.deepSeekModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.maxTokens).toBe(8192);
|
||||
expect(model.info.contextWindow).toBe(64_000);
|
||||
expect(model.info.supportsImages).toBe(false);
|
||||
expect(model.info.supportsPromptCache).toBe(false);
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info for valid model ID", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.deepSeekModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.maxTokens).toBe(8192)
|
||||
expect(model.info.contextWindow).toBe(64_000)
|
||||
expect(model.info.supportsImages).toBe(false)
|
||||
expect(model.info.supportsPromptCache).toBe(false)
|
||||
})
|
||||
|
||||
it('should return provided model ID with default model info if model does not exist', () => {
|
||||
it("should return provided model ID with default model info if model does not exist", () => {
|
||||
const handlerWithInvalidModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: 'invalid-model'
|
||||
});
|
||||
const model = handlerWithInvalidModel.getModel();
|
||||
expect(model.id).toBe('invalid-model'); // Returns provided ID
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info).toBe(handler.getModel().info); // But uses default model info
|
||||
});
|
||||
deepSeekModelId: "invalid-model",
|
||||
})
|
||||
const model = handlerWithInvalidModel.getModel()
|
||||
expect(model.id).toBe("invalid-model") // Returns provided ID
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info).toBe(handler.getModel().info) // But uses default model info
|
||||
})
|
||||
|
||||
it('should return default model if no model ID is provided', () => {
|
||||
it("should return default model if no model ID is provided", () => {
|
||||
const handlerWithoutModel = new DeepSeekHandler({
|
||||
...mockOptions,
|
||||
deepSeekModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBe(deepSeekDefaultModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
deepSeekModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBe(deepSeekDefaultModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello!'
|
||||
}]
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello!",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
|
||||
it('should include usage information', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
it("should include usage information", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
const usageChunks = chunks.filter(chunk => chunk.type === 'usage');
|
||||
expect(usageChunks.length).toBeGreaterThan(0);
|
||||
expect(usageChunks[0].inputTokens).toBe(10);
|
||||
expect(usageChunks[0].outputTokens).toBe(5);
|
||||
});
|
||||
});
|
||||
});
|
||||
const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
|
||||
expect(usageChunks.length).toBeGreaterThan(0)
|
||||
expect(usageChunks[0].inputTokens).toBe(10)
|
||||
expect(usageChunks[0].outputTokens).toBe(5)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,212 +1,210 @@
|
||||
import { GeminiHandler } from '../gemini';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai';
|
||||
import { GeminiHandler } from "../gemini"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { GoogleGenerativeAI } from "@google/generative-ai"
|
||||
|
||||
// Mock the Google Generative AI SDK
|
||||
jest.mock('@google/generative-ai', () => ({
|
||||
jest.mock("@google/generative-ai", () => ({
|
||||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
||||
getGenerativeModel: jest.fn().mockReturnValue({
|
||||
generateContentStream: jest.fn(),
|
||||
generateContent: jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => 'Test response'
|
||||
}
|
||||
})
|
||||
})
|
||||
text: () => "Test response",
|
||||
},
|
||||
}),
|
||||
}),
|
||||
})),
|
||||
}))
|
||||
}));
|
||||
|
||||
describe('GeminiHandler', () => {
|
||||
let handler: GeminiHandler;
|
||||
describe("GeminiHandler", () => {
|
||||
let handler: GeminiHandler
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new GeminiHandler({
|
||||
apiKey: 'test-key',
|
||||
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
|
||||
geminiApiKey: 'test-key'
|
||||
});
|
||||
});
|
||||
apiKey: "test-key",
|
||||
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
||||
geminiApiKey: "test-key",
|
||||
})
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
expect(handler['options'].geminiApiKey).toBe('test-key');
|
||||
expect(handler['options'].apiModelId).toBe('gemini-2.0-flash-thinking-exp-1219');
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided config", () => {
|
||||
expect(handler["options"].geminiApiKey).toBe("test-key")
|
||||
expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
|
||||
})
|
||||
|
||||
it('should throw if API key is missing', () => {
|
||||
it("should throw if API key is missing", () => {
|
||||
expect(() => {
|
||||
new GeminiHandler({
|
||||
apiModelId: 'gemini-2.0-flash-thinking-exp-1219',
|
||||
geminiApiKey: ''
|
||||
});
|
||||
}).toThrow('API key is required for Google Gemini');
|
||||
});
|
||||
});
|
||||
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
||||
geminiApiKey: "",
|
||||
})
|
||||
}).toThrow("API key is required for Google Gemini")
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
describe("createMessage", () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
|
||||
it('should handle text messages correctly', async () => {
|
||||
it("should handle text messages correctly", async () => {
|
||||
// Mock the stream response
|
||||
const mockStream = {
|
||||
stream: [
|
||||
{ text: () => 'Hello' },
|
||||
{ text: () => ' world!' }
|
||||
],
|
||||
stream: [{ text: () => "Hello" }, { text: () => " world!" }],
|
||||
response: {
|
||||
usageMetadata: {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 5
|
||||
candidatesTokenCount: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Setup the mock implementation
|
||||
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream);
|
||||
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContentStream: mockGenerateContentStream
|
||||
});
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
})
|
||||
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
// Should have 3 chunks: 'Hello', ' world!', and usage info
|
||||
expect(chunks.length).toBe(3);
|
||||
expect(chunks.length).toBe(3)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: ' world!'
|
||||
});
|
||||
type: "text",
|
||||
text: " world!",
|
||||
})
|
||||
expect(chunks[2]).toEqual({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 5
|
||||
});
|
||||
outputTokens: 5,
|
||||
})
|
||||
|
||||
// Verify the model configuration
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||
model: 'gemini-2.0-flash-thinking-exp-1219',
|
||||
systemInstruction: systemPrompt
|
||||
});
|
||||
model: "gemini-2.0-flash-thinking-exp-1219",
|
||||
systemInstruction: systemPrompt,
|
||||
})
|
||||
|
||||
// Verify generation config
|
||||
expect(mockGenerateContentStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
generationConfig: {
|
||||
temperature: 0
|
||||
}
|
||||
temperature: 0,
|
||||
},
|
||||
}),
|
||||
)
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Gemini API error');
|
||||
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError);
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Gemini API error")
|
||||
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContentStream: mockGenerateContentStream
|
||||
});
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
})
|
||||
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('Gemini API error');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("Gemini API error")
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => 'Test response'
|
||||
}
|
||||
});
|
||||
text: () => "Test response",
|
||||
},
|
||||
})
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
generateContent: mockGenerateContent,
|
||||
})
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
|
||||
model: 'gemini-2.0-flash-thinking-exp-1219'
|
||||
});
|
||||
model: "gemini-2.0-flash-thinking-exp-1219",
|
||||
})
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith({
|
||||
contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
|
||||
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
|
||||
generationConfig: {
|
||||
temperature: 0
|
||||
}
|
||||
});
|
||||
});
|
||||
temperature: 0,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Gemini API error');
|
||||
const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Gemini API error")
|
||||
const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
generateContent: mockGenerateContent,
|
||||
})
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Gemini completion error: Gemini API error');
|
||||
});
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"Gemini completion error: Gemini API error",
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({
|
||||
response: {
|
||||
text: () => ''
|
||||
}
|
||||
});
|
||||
text: () => "",
|
||||
},
|
||||
})
|
||||
const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent
|
||||
});
|
||||
(handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
|
||||
generateContent: mockGenerateContent,
|
||||
})
|
||||
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
||||
expect(modelInfo.info.contextWindow).toBe(32_767);
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return correct model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||
expect(modelInfo.info.contextWindow).toBe(32_767)
|
||||
})
|
||||
|
||||
it('should return default model if invalid model specified', () => {
|
||||
it("should return default model if invalid model specified", () => {
|
||||
const invalidHandler = new GeminiHandler({
|
||||
apiModelId: 'invalid-model',
|
||||
geminiApiKey: 'test-key'
|
||||
});
|
||||
const modelInfo = invalidHandler.getModel();
|
||||
expect(modelInfo.id).toBe('gemini-2.0-flash-thinking-exp-1219'); // Default model
|
||||
});
|
||||
});
|
||||
});
|
||||
apiModelId: "invalid-model",
|
||||
geminiApiKey: "test-key",
|
||||
})
|
||||
const modelInfo = invalidHandler.getModel()
|
||||
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") // Default model
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { GlamaHandler } from '../glama';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import axios from 'axios';
|
||||
import { GlamaHandler } from "../glama"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import axios from "axios"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
const mockWithResponse = jest.fn();
|
||||
const mockCreate = jest.fn()
|
||||
const mockWithResponse = jest.fn()
|
||||
|
||||
jest.mock('openai', () => {
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
@@ -18,209 +18,221 @@ jest.mock('openai', () => {
|
||||
const stream = {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
const result = mockCreate(...args);
|
||||
const result = mockCreate(...args)
|
||||
if (args[0].stream) {
|
||||
mockWithResponse.mockReturnValue(Promise.resolve({
|
||||
mockWithResponse.mockReturnValue(
|
||||
Promise.resolve({
|
||||
data: stream,
|
||||
response: {
|
||||
headers: {
|
||||
get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
|
||||
get: (name: string) =>
|
||||
name === "x-completion-request-id" ? "test-request-id" : null,
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
result.withResponse = mockWithResponse
|
||||
}
|
||||
return result
|
||||
},
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
}));
|
||||
result.withResponse = mockWithResponse;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
})
|
||||
|
||||
describe('GlamaHandler', () => {
|
||||
let handler: GlamaHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
describe("GlamaHandler", () => {
|
||||
let handler: GlamaHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'anthropic/claude-3-5-sonnet',
|
||||
glamaModelId: 'anthropic/claude-3-5-sonnet',
|
||||
glamaApiKey: 'test-api-key'
|
||||
};
|
||||
handler = new GlamaHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
mockWithResponse.mockClear();
|
||||
apiModelId: "anthropic/claude-3-5-sonnet",
|
||||
glamaModelId: "anthropic/claude-3-5-sonnet",
|
||||
glamaApiKey: "test-api-key",
|
||||
}
|
||||
handler = new GlamaHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
mockWithResponse.mockClear()
|
||||
|
||||
// Default mock implementation for non-streaming responses
|
||||
mockCreate.mockResolvedValue({
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response" },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
});
|
||||
});
|
||||
total_tokens: 15,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(GlamaHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(GlamaHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: "Hello!",
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
it("should handle streaming responses", async () => {
|
||||
// Mock axios for token usage request
|
||||
const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
|
||||
const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({
|
||||
data: {
|
||||
tokenUsage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
cacheCreationInputTokens: 0,
|
||||
cacheReadInputTokens: 0
|
||||
cacheReadInputTokens: 0,
|
||||
},
|
||||
totalCostUsd: "0.00"
|
||||
}
|
||||
});
|
||||
totalCostUsd: "0.00",
|
||||
},
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBe(2); // Text chunk and usage chunk
|
||||
expect(chunks.length).toBe(2) // Text chunk and usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
});
|
||||
type: "text",
|
||||
text: "Test response",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 5,
|
||||
cacheWriteTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
totalCost: 0
|
||||
});
|
||||
totalCost: 0,
|
||||
})
|
||||
|
||||
mockAxios.mockRestore();
|
||||
});
|
||||
mockAxios.mockRestore()
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockImplementationOnce(() => {
|
||||
throw new Error('API Error');
|
||||
});
|
||||
throw new Error("API Error")
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
|
||||
try {
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
fail('Expected error to be thrown');
|
||||
fail("Expected error to be thrown")
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(Error);
|
||||
expect(error.message).toBe('API Error');
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect(error.message).toBe("API Error")
|
||||
}
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: mockOptions.apiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
max_tokens: 8192
|
||||
}));
|
||||
});
|
||||
max_tokens: 8192,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Glama completion error: API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Glama completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
choices: [{ message: { content: "" } }],
|
||||
})
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
it('should not set max_tokens for non-Anthropic models', async () => {
|
||||
it("should not set max_tokens for non-Anthropic models", async () => {
|
||||
// Reset mock to clear any previous calls
|
||||
mockCreate.mockClear();
|
||||
mockCreate.mockClear()
|
||||
|
||||
const nonAnthropicOptions = {
|
||||
apiModelId: 'openai/gpt-4',
|
||||
glamaModelId: 'openai/gpt-4',
|
||||
glamaApiKey: 'test-key',
|
||||
apiModelId: "openai/gpt-4",
|
||||
glamaModelId: "openai/gpt-4",
|
||||
glamaApiKey: "test-key",
|
||||
glamaModelInfo: {
|
||||
maxTokens: 4096,
|
||||
contextWindow: 8192,
|
||||
supportsImages: true,
|
||||
supportsPromptCache: false
|
||||
supportsPromptCache: false,
|
||||
},
|
||||
}
|
||||
};
|
||||
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
|
||||
const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions)
|
||||
|
||||
await nonAnthropicHandler.completePrompt('Test prompt');
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: 'openai/gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
}));
|
||||
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
|
||||
});
|
||||
});
|
||||
await nonAnthropicHandler.completePrompt("Test prompt")
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: "openai/gpt-4",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
}),
|
||||
)
|
||||
expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000);
|
||||
});
|
||||
});
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId)
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { LmStudioHandler } from '../lmstudio';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { LmStudioHandler } from "../lmstudio"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
@@ -14,147 +14,154 @@ jest.mock('openai', () => {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response" },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('LmStudioHandler', () => {
|
||||
let handler: LmStudioHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
describe("LmStudioHandler", () => {
|
||||
let handler: LmStudioHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'local-model',
|
||||
lmStudioModelId: 'local-model',
|
||||
lmStudioBaseUrl: 'http://localhost:1234/v1'
|
||||
};
|
||||
handler = new LmStudioHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
apiModelId: "local-model",
|
||||
lmStudioModelId: "local-model",
|
||||
lmStudioBaseUrl: "http://localhost:1234/v1",
|
||||
}
|
||||
handler = new LmStudioHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(LmStudioHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(LmStudioHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId)
|
||||
})
|
||||
|
||||
it('should use default base URL if not provided', () => {
|
||||
it("should use default base URL if not provided", () => {
|
||||
const handlerWithoutUrl = new LmStudioHandler({
|
||||
apiModelId: 'local-model',
|
||||
lmStudioModelId: 'local-model'
|
||||
});
|
||||
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
|
||||
});
|
||||
});
|
||||
apiModelId: "local-model",
|
||||
lmStudioModelId: "local-model",
|
||||
})
|
||||
expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: "Hello!",
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("Please check the LM Studio developer logs to debug what went wrong")
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.lmStudioModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"Please check the LM Studio developer logs to debug what went wrong",
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
choices: [{ message: { content: "" } }],
|
||||
})
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
});
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(-1)
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { OllamaHandler } from '../ollama';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { OllamaHandler } from "../ollama"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
@@ -14,147 +14,152 @@ jest.mock('openai', () => {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response" },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('OllamaHandler', () => {
|
||||
let handler: OllamaHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
describe("OllamaHandler", () => {
|
||||
let handler: OllamaHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'llama2',
|
||||
ollamaModelId: 'llama2',
|
||||
ollamaBaseUrl: 'http://localhost:11434/v1'
|
||||
};
|
||||
handler = new OllamaHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
apiModelId: "llama2",
|
||||
ollamaModelId: "llama2",
|
||||
ollamaBaseUrl: "http://localhost:11434/v1",
|
||||
}
|
||||
handler = new OllamaHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OllamaHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(OllamaHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.ollamaModelId)
|
||||
})
|
||||
|
||||
it('should use default base URL if not provided', () => {
|
||||
it("should use default base URL if not provided", () => {
|
||||
const handlerWithoutUrl = new OllamaHandler({
|
||||
apiModelId: 'llama2',
|
||||
ollamaModelId: 'llama2'
|
||||
});
|
||||
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
|
||||
});
|
||||
});
|
||||
apiModelId: "llama2",
|
||||
ollamaModelId: "llama2",
|
||||
})
|
||||
expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: "Hello!",
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.ollamaModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Ollama completion error: API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Ollama completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
choices: [{ message: { content: "" } }],
|
||||
})
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(-1);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
});
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe(mockOptions.ollamaModelId)
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(-1)
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { OpenAiNativeHandler } from '../openai-native';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { OpenAiNativeHandler } from "../openai-native"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
@@ -14,306 +14,313 @@ jest.mock('openai', () => {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response" },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('OpenAiNativeHandler', () => {
|
||||
let handler: OpenAiNativeHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("OpenAiNativeHandler", () => {
|
||||
let handler: OpenAiNativeHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: "Hello!",
|
||||
},
|
||||
]
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
};
|
||||
handler = new OpenAiNativeHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
apiModelId: "gpt-4o",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
}
|
||||
handler = new OpenAiNativeHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
||||
})
|
||||
|
||||
it('should initialize with empty API key', () => {
|
||||
it("should initialize with empty API key", () => {
|
||||
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: ''
|
||||
});
|
||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
|
||||
});
|
||||
});
|
||||
apiModelId: "gpt-4o",
|
||||
openAiNativeApiKey: "",
|
||||
})
|
||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
describe("createMessage", () => {
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
|
||||
it('should handle missing content in response for o1 model', async () => {
|
||||
it("should handle missing content in response for o1 model", async () => {
|
||||
// Use o1 model which supports developer role
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: 'o1'
|
||||
});
|
||||
apiModelId: "o1",
|
||||
})
|
||||
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: null } }],
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
}
|
||||
});
|
||||
total_tokens: 0,
|
||||
},
|
||||
})
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages);
|
||||
const results = [];
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result);
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: 'text', text: '' },
|
||||
{ type: 'usage', inputTokens: 0, outputTokens: 0 }
|
||||
]);
|
||||
{ type: "text", text: "" },
|
||||
{ type: "usage", inputTokens: 0, outputTokens: 0 },
|
||||
])
|
||||
|
||||
// Verify developer role is used for system prompt with o1 model
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1',
|
||||
model: "o1",
|
||||
messages: [
|
||||
{ role: 'developer', content: systemPrompt },
|
||||
{ role: 'user', content: 'Hello!' }
|
||||
]
|
||||
});
|
||||
});
|
||||
});
|
||||
{ role: "developer", content: systemPrompt },
|
||||
{ role: "user", content: "Hello!" },
|
||||
],
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('streaming models', () => {
|
||||
describe("streaming models", () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: 'gpt-4o',
|
||||
});
|
||||
});
|
||||
apiModelId: "gpt-4o",
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle streaming response', async () => {
|
||||
it("should handle streaming response", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: { content: 'Hello' } }], usage: null },
|
||||
{ choices: [{ delta: { content: ' there' } }], usage: null },
|
||||
{ choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
];
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: null },
|
||||
{ choices: [{ delta: { content: " there" } }], usage: null },
|
||||
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
yield chunk
|
||||
}
|
||||
})()
|
||||
);
|
||||
})(),
|
||||
)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages);
|
||||
const results = [];
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result);
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: ' there' },
|
||||
{ type: 'text', text: '!' },
|
||||
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
|
||||
]);
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "text", text: " there" },
|
||||
{ type: "text", text: "!" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'gpt-4o',
|
||||
model: "gpt-4o",
|
||||
temperature: 0,
|
||||
messages: [
|
||||
{ role: 'system', content: systemPrompt },
|
||||
{ role: 'user', content: 'Hello!' },
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: "Hello!" },
|
||||
],
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty delta content', async () => {
|
||||
it("should handle empty delta content", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: {} }], usage: null },
|
||||
{ choices: [{ delta: { content: null } }], usage: null },
|
||||
{ choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
];
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
yield chunk
|
||||
}
|
||||
})()
|
||||
);
|
||||
})(),
|
||||
)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages);
|
||||
const results = [];
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result);
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'usage', inputTokens: 10, outputTokens: 5 },
|
||||
]);
|
||||
});
|
||||
});
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully with gpt-4o model', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully with gpt-4o model", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
model: "gpt-4o",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
})
|
||||
})
|
||||
|
||||
it('should complete prompt successfully with o1 model', async () => {
|
||||
it("should complete prompt successfully with o1 model", async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
apiModelId: "o1",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
model: "o1",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
})
|
||||
})
|
||||
|
||||
it('should complete prompt successfully with o1-preview model', async () => {
|
||||
it("should complete prompt successfully with o1-preview model", async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-preview',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
apiModelId: "o1-preview",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-preview',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
model: "o1-preview",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
})
|
||||
})
|
||||
|
||||
it('should complete prompt successfully with o1-mini model', async () => {
|
||||
it("should complete prompt successfully with o1-mini model", async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-mini',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
apiModelId: "o1-mini",
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-mini',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
model: "o1-mini",
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI Native completion error: API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"OpenAI Native completion error: API Error",
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
choices: [{ message: { content: "" } }],
|
||||
})
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(4096);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId)
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(4096)
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000)
|
||||
})
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
it("should handle undefined model ID", () => {
|
||||
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
const modelInfo = handlerWithoutModel.getModel();
|
||||
expect(modelInfo.id).toBe('gpt-4o'); // Default model
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
})
|
||||
const modelInfo = handlerWithoutModel.getModel()
|
||||
expect(modelInfo.id).toBe("gpt-4o") // Default model
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { OpenAiHandler } from '../openai';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import { ApiStream } from '../../transform/stream';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { OpenAiHandler } from "../openai"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import { ApiStream } from "../../transform/stream"
|
||||
import OpenAI from "openai"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
const mockCreate = jest.fn()
|
||||
jest.mock("openai", () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
@@ -15,210 +15,219 @@ jest.mock('openai', () => {
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response', refusal: null },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
id: "test-completion",
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content: "Test response", refusal: null },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
delta: { content: "Test response" },
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: null,
|
||||
}
|
||||
yield {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
}
|
||||
};
|
||||
},
|
||||
}
|
||||
}),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
describe('OpenAiHandler', () => {
|
||||
let handler: OpenAiHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
describe("OpenAiHandler", () => {
|
||||
let handler: OpenAiHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
openAiApiKey: 'test-api-key',
|
||||
openAiModelId: 'gpt-4',
|
||||
openAiBaseUrl: 'https://api.openai.com/v1'
|
||||
};
|
||||
handler = new OpenAiHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
openAiApiKey: "test-api-key",
|
||||
openAiModelId: "gpt-4",
|
||||
openAiBaseUrl: "https://api.openai.com/v1",
|
||||
}
|
||||
handler = new OpenAiHandler(mockOptions)
|
||||
mockCreate.mockClear()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.openAiModelId);
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiHandler)
|
||||
expect(handler.getModel().id).toBe(mockOptions.openAiModelId)
|
||||
})
|
||||
|
||||
it('should use custom base URL if provided', () => {
|
||||
const customBaseUrl = 'https://custom.openai.com/v1';
|
||||
it("should use custom base URL if provided", () => {
|
||||
const customBaseUrl = "https://custom.openai.com/v1"
|
||||
const handlerWithCustomUrl = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiBaseUrl: customBaseUrl
|
||||
});
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler);
|
||||
});
|
||||
});
|
||||
openAiBaseUrl: customBaseUrl,
|
||||
})
|
||||
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant."
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello!'
|
||||
}]
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello!",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle non-streaming mode', async () => {
|
||||
it("should handle non-streaming mode", async () => {
|
||||
const handler = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiStreamingEnabled: false
|
||||
});
|
||||
openAiStreamingEnabled: false,
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunk = chunks.find(chunk => chunk.type === 'text');
|
||||
const usageChunk = chunks.find(chunk => chunk.type === 'usage');
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunk = chunks.find((chunk) => chunk.type === "text")
|
||||
const usageChunk = chunks.find((chunk) => chunk.type === "usage")
|
||||
|
||||
expect(textChunk).toBeDefined();
|
||||
expect(textChunk?.text).toBe('Test response');
|
||||
expect(usageChunk).toBeDefined();
|
||||
expect(usageChunk?.inputTokens).toBe(10);
|
||||
expect(usageChunk?.outputTokens).toBe(5);
|
||||
});
|
||||
expect(textChunk).toBeDefined()
|
||||
expect(textChunk?.text).toBe("Test response")
|
||||
expect(usageChunk).toBeDefined()
|
||||
expect(usageChunk?.inputTokens).toBe(10)
|
||||
expect(usageChunk?.outputTokens).toBe(5)
|
||||
})
|
||||
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
it("should handle streaming responses", async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks: any[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
});
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
||||
expect(textChunks).toHaveLength(1)
|
||||
expect(textChunks[0].text).toBe("Test response")
|
||||
})
|
||||
})
|
||||
|
||||
describe('error handling', () => {
|
||||
describe("error handling", () => {
|
||||
const testMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'text' as const,
|
||||
text: 'Hello'
|
||||
}]
|
||||
}
|
||||
];
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Hello",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
const stream = handler.createMessage('system prompt', testMessages);
|
||||
const stream = handler.createMessage("system prompt", testMessages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
|
||||
it('should handle rate limiting', async () => {
|
||||
const rateLimitError = new Error('Rate limit exceeded');
|
||||
rateLimitError.name = 'Error';
|
||||
(rateLimitError as any).status = 429;
|
||||
mockCreate.mockRejectedValueOnce(rateLimitError);
|
||||
it("should handle rate limiting", async () => {
|
||||
const rateLimitError = new Error("Rate limit exceeded")
|
||||
rateLimitError.name = "Error"
|
||||
;(rateLimitError as any).status = 429
|
||||
mockCreate.mockRejectedValueOnce(rateLimitError)
|
||||
|
||||
const stream = handler.createMessage('system prompt', testMessages);
|
||||
const stream = handler.createMessage("system prompt", testMessages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow('Rate limit exceeded');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("Rate limit exceeded")
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openAiModelId,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
temperature: 0,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI completion error: API Error');
|
||||
});
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
mockCreate.mockImplementationOnce(() => ({
|
||||
choices: [{ message: { content: '' } }]
|
||||
}));
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
choices: [{ message: { content: "" } }],
|
||||
}))
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info with sane defaults', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe(mockOptions.openAiModelId);
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.contextWindow).toBe(128_000);
|
||||
expect(model.info.supportsImages).toBe(true);
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return model info with sane defaults", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe(mockOptions.openAiModelId)
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.contextWindow).toBe(128_000)
|
||||
expect(model.info.supportsImages).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
it("should handle undefined model ID", () => {
|
||||
const handlerWithoutModel = new OpenAiHandler({
|
||||
...mockOptions,
|
||||
openAiModelId: undefined
|
||||
});
|
||||
const model = handlerWithoutModel.getModel();
|
||||
expect(model.id).toBe('');
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
openAiModelId: undefined,
|
||||
})
|
||||
const model = handlerWithoutModel.getModel()
|
||||
expect(model.id).toBe("")
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,83 +1,85 @@
|
||||
import { OpenRouterHandler } from '../openrouter'
|
||||
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
|
||||
import OpenAI from 'openai'
|
||||
import axios from 'axios'
|
||||
import { Anthropic } from '@anthropic-ai/sdk'
|
||||
import { OpenRouterHandler } from "../openrouter"
|
||||
import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
|
||||
import OpenAI from "openai"
|
||||
import axios from "axios"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('openai')
|
||||
jest.mock('axios')
|
||||
jest.mock('delay', () => jest.fn(() => Promise.resolve()))
|
||||
jest.mock("openai")
|
||||
jest.mock("axios")
|
||||
jest.mock("delay", () => jest.fn(() => Promise.resolve()))
|
||||
|
||||
describe('OpenRouterHandler', () => {
|
||||
describe("OpenRouterHandler", () => {
|
||||
const mockOptions: ApiHandlerOptions = {
|
||||
openRouterApiKey: 'test-key',
|
||||
openRouterModelId: 'test-model',
|
||||
openRouterApiKey: "test-key",
|
||||
openRouterModelId: "test-model",
|
||||
openRouterModelInfo: {
|
||||
name: 'Test Model',
|
||||
description: 'Test Description',
|
||||
name: "Test Model",
|
||||
description: "Test Description",
|
||||
maxTokens: 1000,
|
||||
contextWindow: 2000,
|
||||
supportsPromptCache: true,
|
||||
inputPrice: 0.01,
|
||||
outputPrice: 0.02
|
||||
} as ModelInfo
|
||||
outputPrice: 0.02,
|
||||
} as ModelInfo,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
test('constructor initializes with correct options', () => {
|
||||
test("constructor initializes with correct options", () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
expect(handler).toBeInstanceOf(OpenRouterHandler)
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
baseURL: 'https://openrouter.ai/api/v1',
|
||||
baseURL: "https://openrouter.ai/api/v1",
|
||||
apiKey: mockOptions.openRouterApiKey,
|
||||
defaultHeaders: {
|
||||
'HTTP-Referer': 'https://github.com/RooVetGit/Roo-Cline',
|
||||
'X-Title': 'Roo-Cline',
|
||||
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
|
||||
"X-Title": "Roo-Cline",
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test('getModel returns correct model info when options are provided', () => {
|
||||
test("getModel returns correct model info when options are provided", () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const result = handler.getModel()
|
||||
|
||||
expect(result).toEqual({
|
||||
id: mockOptions.openRouterModelId,
|
||||
info: mockOptions.openRouterModelInfo
|
||||
info: mockOptions.openRouterModelInfo,
|
||||
})
|
||||
})
|
||||
|
||||
test('getModel returns default model info when options are not provided', () => {
|
||||
test("getModel returns default model info when options are not provided", () => {
|
||||
const handler = new OpenRouterHandler({})
|
||||
const result = handler.getModel()
|
||||
|
||||
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
|
||||
expect(result.id).toBe("anthropic/claude-3.5-sonnet:beta")
|
||||
expect(result.info.supportsPromptCache).toBe(true)
|
||||
})
|
||||
|
||||
test('createMessage generates correct stream chunks', async () => {
|
||||
test("createMessage generates correct stream chunks", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
id: "test-id",
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
content: "test response",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Mock OpenAI chat.completions.create
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
// Mock axios.get for generation details
|
||||
@@ -86,13 +88,13 @@ describe('OpenRouterHandler', () => {
|
||||
data: {
|
||||
native_tokens_prompt: 10,
|
||||
native_tokens_completion: 20,
|
||||
total_cost: 0.001
|
||||
}
|
||||
}
|
||||
total_cost: 0.001,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const systemPrompt = 'test system prompt'
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }]
|
||||
const systemPrompt = "test system prompt"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
@@ -104,180 +106,192 @@ describe('OpenRouterHandler', () => {
|
||||
// Verify stream chunks
|
||||
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'test response'
|
||||
type: "text",
|
||||
text: "test response",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalCost: 0.001,
|
||||
fullResponseText: 'test response'
|
||||
fullResponseText: "test response",
|
||||
})
|
||||
|
||||
// Verify OpenAI client was called with correct parameters
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: mockOptions.openRouterModelId,
|
||||
temperature: 0,
|
||||
messages: expect.arrayContaining([
|
||||
{ role: 'system', content: systemPrompt },
|
||||
{ role: 'user', content: 'test message' }
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: "test message" },
|
||||
]),
|
||||
stream: true
|
||||
}))
|
||||
stream: true,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
test('createMessage with middle-out transform enabled', async () => {
|
||||
test("createMessage with middle-out transform enabled", async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterUseMiddleOutTransform: true
|
||||
openRouterUseMiddleOutTransform: true,
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
id: "test-id",
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
content: "test response",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
|
||||
await handler.createMessage('test', []).next()
|
||||
await handler.createMessage("test", []).next()
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
transforms: ['middle-out']
|
||||
}))
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
transforms: ["middle-out"],
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
test('createMessage with Claude model adds cache control', async () => {
|
||||
test("createMessage with Claude model adds cache control", async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterModelId: 'anthropic/claude-3.5-sonnet'
|
||||
openRouterModelId: "anthropic/claude-3.5-sonnet",
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
id: "test-id",
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
content: "test response",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'message 1' },
|
||||
{ role: 'assistant', content: 'response 1' },
|
||||
{ role: 'user', content: 'message 2' }
|
||||
{ role: "user", content: "message 1" },
|
||||
{ role: "assistant", content: "response 1" },
|
||||
{ role: "user", content: "message 2" },
|
||||
]
|
||||
|
||||
await handler.createMessage('test system', messages).next()
|
||||
await handler.createMessage("test system", messages).next()
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
expect(mockCreate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'system',
|
||||
role: "system",
|
||||
content: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
cache_control: { type: 'ephemeral' }
|
||||
})
|
||||
])
|
||||
})
|
||||
])
|
||||
}))
|
||||
cache_control: { type: "ephemeral" },
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
test('createMessage handles API errors', async () => {
|
||||
test("createMessage handles API errors", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
error: {
|
||||
message: 'API Error',
|
||||
code: 500
|
||||
}
|
||||
}
|
||||
message: "API Error",
|
||||
code: 500,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
const generator = handler.createMessage('test', [])
|
||||
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||
const generator = handler.createMessage("test", [])
|
||||
await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
|
||||
})
|
||||
|
||||
test('completePrompt returns correct response', async () => {
|
||||
test("completePrompt returns correct response", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockResponse = {
|
||||
choices: [{
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'test completion'
|
||||
}
|
||||
}]
|
||||
content: "test completion",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
const result = await handler.completePrompt('test prompt')
|
||||
const result = await handler.completePrompt("test prompt")
|
||||
|
||||
expect(result).toBe('test completion')
|
||||
expect(result).toBe("test completion")
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openRouterModelId,
|
||||
messages: [{ role: 'user', content: 'test prompt' }],
|
||||
messages: [{ role: "user", content: "test prompt" }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
test('completePrompt handles API errors', async () => {
|
||||
test("completePrompt handles API errors", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockError = {
|
||||
error: {
|
||||
message: 'API Error',
|
||||
code: 500
|
||||
}
|
||||
message: "API Error",
|
||||
code: 500,
|
||||
},
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
await expect(handler.completePrompt('test prompt'))
|
||||
.rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||
await expect(handler.completePrompt("test prompt")).rejects.toThrow("OpenRouter API Error 500: API Error")
|
||||
})
|
||||
|
||||
test('completePrompt handles unexpected errors', async () => {
|
||||
test("completePrompt handles unexpected errors", async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
|
||||
const mockCreate = jest.fn().mockRejectedValue(new Error("Unexpected error"))
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
completions: { create: mockCreate },
|
||||
} as any
|
||||
|
||||
await expect(handler.completePrompt('test prompt'))
|
||||
.rejects.toThrow('OpenRouter completion error: Unexpected error')
|
||||
await expect(handler.completePrompt("test prompt")).rejects.toThrow(
|
||||
"OpenRouter completion error: Unexpected error",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,296 +1,295 @@
|
||||
import { VertexHandler } from '../vertex';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
|
||||
import { VertexHandler } from "../vertex"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
|
||||
|
||||
// Mock Vertex SDK
|
||||
jest.mock('@anthropic-ai/vertex-sdk', () => ({
|
||||
jest.mock("@anthropic-ai/vertex-sdk", () => ({
|
||||
AnthropicVertex: jest.fn().mockImplementation(() => ({
|
||||
messages: {
|
||||
create: jest.fn().mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
content: [
|
||||
{ type: 'text', text: 'Test response' }
|
||||
],
|
||||
role: 'assistant',
|
||||
id: "test-completion",
|
||||
content: [{ type: "text", text: "Test response" }],
|
||||
role: "assistant",
|
||||
model: options.model,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
output_tokens: 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: 'message_start',
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
output_tokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
yield {
|
||||
type: 'content_block_start',
|
||||
type: "content_block_start",
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Test response'
|
||||
type: "text",
|
||||
text: "Test response",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}),
|
||||
},
|
||||
})),
|
||||
}))
|
||||
}));
|
||||
|
||||
describe('VertexHandler', () => {
|
||||
let handler: VertexHandler;
|
||||
describe("VertexHandler", () => {
|
||||
let handler: VertexHandler
|
||||
|
||||
beforeEach(() => {
|
||||
handler = new VertexHandler({
|
||||
apiModelId: 'claude-3-5-sonnet-v2@20241022',
|
||||
vertexProjectId: 'test-project',
|
||||
vertexRegion: 'us-central1'
|
||||
});
|
||||
});
|
||||
apiModelId: "claude-3-5-sonnet-v2@20241022",
|
||||
vertexProjectId: "test-project",
|
||||
vertexRegion: "us-central1",
|
||||
})
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided config', () => {
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided config", () => {
|
||||
expect(AnthropicVertex).toHaveBeenCalledWith({
|
||||
projectId: 'test-project',
|
||||
region: 'us-central1'
|
||||
});
|
||||
});
|
||||
});
|
||||
projectId: "test-project",
|
||||
region: "us-central1",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
describe("createMessage", () => {
|
||||
const mockMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
|
||||
it('should handle streaming responses correctly', async () => {
|
||||
it("should handle streaming responses correctly", async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
type: 'message_start',
|
||||
type: "message_start",
|
||||
message: {
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 0
|
||||
}
|
||||
}
|
||||
output_tokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'content_block_start',
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
}
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'content_block_delta',
|
||||
type: "content_block_delta",
|
||||
delta: {
|
||||
type: 'text_delta',
|
||||
text: ' world!'
|
||||
}
|
||||
type: "text_delta",
|
||||
text: " world!",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'message_delta',
|
||||
type: "message_delta",
|
||||
usage: {
|
||||
output_tokens: 5
|
||||
}
|
||||
}
|
||||
];
|
||||
output_tokens: 5,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
// Setup async iterator for mock stream
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
yield chunk
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBe(4);
|
||||
expect(chunks.length).toBe(4)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 10,
|
||||
outputTokens: 0
|
||||
});
|
||||
outputTokens: 0,
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
})
|
||||
expect(chunks[2]).toEqual({
|
||||
type: 'text',
|
||||
text: ' world!'
|
||||
});
|
||||
type: "text",
|
||||
text: " world!",
|
||||
})
|
||||
expect(chunks[3]).toEqual({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 5
|
||||
});
|
||||
outputTokens: 5,
|
||||
})
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'claude-3-5-sonnet-v2@20241022',
|
||||
model: "claude-3-5-sonnet-v2@20241022",
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
system: systemPrompt,
|
||||
messages: mockMessages,
|
||||
stream: true
|
||||
});
|
||||
});
|
||||
stream: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle multiple content blocks with line breaks', async () => {
|
||||
it("should handle multiple content blocks with line breaks", async () => {
|
||||
const mockStream = [
|
||||
{
|
||||
type: 'content_block_start',
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'First line'
|
||||
}
|
||||
type: "text",
|
||||
text: "First line",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'content_block_start',
|
||||
type: "content_block_start",
|
||||
index: 1,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: 'Second line'
|
||||
}
|
||||
}
|
||||
];
|
||||
type: "text",
|
||||
text: "Second line",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
const asyncIterator = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk;
|
||||
yield chunk
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
const chunks = []
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBe(3);
|
||||
expect(chunks.length).toBe(3)
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'First line'
|
||||
});
|
||||
type: "text",
|
||||
text: "First line",
|
||||
})
|
||||
expect(chunks[1]).toEqual({
|
||||
type: 'text',
|
||||
text: '\n'
|
||||
});
|
||||
type: "text",
|
||||
text: "\n",
|
||||
})
|
||||
expect(chunks[2]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Second line'
|
||||
});
|
||||
});
|
||||
type: "text",
|
||||
text: "Second line",
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Vertex API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Vertex API error")
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages);
|
||||
const stream = handler.createMessage(systemPrompt, mockMessages)
|
||||
|
||||
await expect(async () => {
|
||||
for await (const chunk of stream) {
|
||||
// Should throw before yielding any chunks
|
||||
}
|
||||
}).rejects.toThrow('Vertex API error');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("Vertex API error")
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(handler['client'].messages.create).toHaveBeenCalledWith({
|
||||
model: 'claude-3-5-sonnet-v2@20241022',
|
||||
describe("completePrompt", () => {
|
||||
it("should complete prompt successfully", async () => {
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("Test response")
|
||||
expect(handler["client"].messages.create).toHaveBeenCalledWith({
|
||||
model: "claude-3-5-sonnet-v2@20241022",
|
||||
max_tokens: 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
stream: false
|
||||
});
|
||||
});
|
||||
messages: [{ role: "user", content: "Test prompt" }],
|
||||
stream: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
const mockError = new Error('Vertex API error');
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError);
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
it("should handle API errors", async () => {
|
||||
const mockError = new Error("Vertex API error")
|
||||
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('Vertex completion error: Vertex API error');
|
||||
});
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"Vertex completion error: Vertex API error",
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle non-text content', async () => {
|
||||
it("should handle non-text content", async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'image' }]
|
||||
});
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
content: [{ type: "image" }],
|
||||
})
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
it("should handle empty response", async () => {
|
||||
const mockCreate = jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: '' }]
|
||||
});
|
||||
(handler['client'].messages as any).create = mockCreate;
|
||||
content: [{ type: "text", text: "" }],
|
||||
})
|
||||
;(handler["client"].messages as any).create = mockCreate
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return correct model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022');
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(8192);
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000);
|
||||
});
|
||||
describe("getModel", () => {
|
||||
it("should return correct model info", () => {
|
||||
const modelInfo = handler.getModel()
|
||||
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
|
||||
expect(modelInfo.info).toBeDefined()
|
||||
expect(modelInfo.info.maxTokens).toBe(8192)
|
||||
expect(modelInfo.info.contextWindow).toBe(200_000)
|
||||
})
|
||||
|
||||
it('should return default model if invalid model specified', () => {
|
||||
it("should return default model if invalid model specified", () => {
|
||||
const invalidHandler = new VertexHandler({
|
||||
apiModelId: 'invalid-model',
|
||||
vertexProjectId: 'test-project',
|
||||
vertexRegion: 'us-central1'
|
||||
});
|
||||
const modelInfo = invalidHandler.getModel();
|
||||
expect(modelInfo.id).toBe('claude-3-5-sonnet-v2@20241022'); // Default model
|
||||
});
|
||||
});
|
||||
});
|
||||
apiModelId: "invalid-model",
|
||||
vertexProjectId: "test-project",
|
||||
vertexRegion: "us-central1",
|
||||
})
|
||||
const modelInfo = invalidHandler.getModel()
|
||||
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022") // Default model
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,289 +1,295 @@
|
||||
import * as vscode from 'vscode';
|
||||
import { VsCodeLmHandler } from '../vscode-lm';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import * as vscode from "vscode"
|
||||
import { VsCodeLmHandler } from "../vscode-lm"
|
||||
import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
|
||||
// Mock vscode namespace
|
||||
jest.mock('vscode', () => {
|
||||
jest.mock("vscode", () => {
|
||||
class MockLanguageModelTextPart {
|
||||
type = 'text';
|
||||
type = "text"
|
||||
constructor(public value: string) {}
|
||||
}
|
||||
|
||||
class MockLanguageModelToolCallPart {
|
||||
type = 'tool_call';
|
||||
type = "tool_call"
|
||||
constructor(
|
||||
public callId: string,
|
||||
public name: string,
|
||||
public input: any
|
||||
public input: any,
|
||||
) {}
|
||||
}
|
||||
|
||||
return {
|
||||
workspace: {
|
||||
onDidChangeConfiguration: jest.fn((callback) => ({
|
||||
dispose: jest.fn()
|
||||
}))
|
||||
dispose: jest.fn(),
|
||||
})),
|
||||
},
|
||||
CancellationTokenSource: jest.fn(() => ({
|
||||
token: {
|
||||
isCancellationRequested: false,
|
||||
onCancellationRequested: jest.fn()
|
||||
onCancellationRequested: jest.fn(),
|
||||
},
|
||||
cancel: jest.fn(),
|
||||
dispose: jest.fn()
|
||||
dispose: jest.fn(),
|
||||
})),
|
||||
CancellationError: class CancellationError extends Error {
|
||||
constructor() {
|
||||
super('Operation cancelled');
|
||||
this.name = 'CancellationError';
|
||||
super("Operation cancelled")
|
||||
this.name = "CancellationError"
|
||||
}
|
||||
},
|
||||
LanguageModelChatMessage: {
|
||||
Assistant: jest.fn((content) => ({
|
||||
role: 'assistant',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
role: "assistant",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
User: jest.fn((content) => ({
|
||||
role: 'user',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
}))
|
||||
role: "user",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
},
|
||||
LanguageModelTextPart: MockLanguageModelTextPart,
|
||||
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
||||
lm: {
|
||||
selectChatModels: jest.fn()
|
||||
selectChatModels: jest.fn(),
|
||||
},
|
||||
}
|
||||
};
|
||||
});
|
||||
})
|
||||
|
||||
const mockLanguageModelChat = {
|
||||
id: 'test-model',
|
||||
name: 'Test Model',
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family',
|
||||
version: '1.0',
|
||||
id: "test-model",
|
||||
name: "Test Model",
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
version: "1.0",
|
||||
maxInputTokens: 4096,
|
||||
sendRequest: jest.fn(),
|
||||
countTokens: jest.fn()
|
||||
};
|
||||
countTokens: jest.fn(),
|
||||
}
|
||||
|
||||
describe('VsCodeLmHandler', () => {
|
||||
let handler: VsCodeLmHandler;
|
||||
describe("VsCodeLmHandler", () => {
|
||||
let handler: VsCodeLmHandler
|
||||
const defaultOptions: ApiHandlerOptions = {
|
||||
vsCodeLmModelSelector: {
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family'
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
handler = new VsCodeLmHandler(defaultOptions);
|
||||
});
|
||||
jest.clearAllMocks()
|
||||
handler = new VsCodeLmHandler(defaultOptions)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
handler.dispose();
|
||||
});
|
||||
handler.dispose()
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeDefined();
|
||||
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled();
|
||||
});
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(handler).toBeDefined()
|
||||
expect(vscode.workspace.onDidChangeConfiguration).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle configuration changes', () => {
|
||||
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0];
|
||||
callback({ affectsConfiguration: () => true });
|
||||
it("should handle configuration changes", () => {
|
||||
const callback = (vscode.workspace.onDidChangeConfiguration as jest.Mock).mock.calls[0][0]
|
||||
callback({ affectsConfiguration: () => true })
|
||||
// Should reset client when config changes
|
||||
expect(handler['client']).toBeNull();
|
||||
});
|
||||
});
|
||||
expect(handler["client"]).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createClient', () => {
|
||||
it('should create client with selector', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
describe("createClient", () => {
|
||||
it("should create client with selector", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
const client = await handler['createClient']({
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family'
|
||||
});
|
||||
const client = await handler["createClient"]({
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
})
|
||||
|
||||
expect(client).toBeDefined();
|
||||
expect(client.id).toBe('test-model');
|
||||
expect(client).toBeDefined()
|
||||
expect(client.id).toBe("test-model")
|
||||
expect(vscode.lm.selectChatModels).toHaveBeenCalledWith({
|
||||
vendor: 'test-vendor',
|
||||
family: 'test-family'
|
||||
});
|
||||
});
|
||||
vendor: "test-vendor",
|
||||
family: "test-family",
|
||||
})
|
||||
})
|
||||
|
||||
it('should return default client when no models available', async () => {
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([]);
|
||||
it("should return default client when no models available", async () => {
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([])
|
||||
|
||||
const client = await handler['createClient']({});
|
||||
const client = await handler["createClient"]({})
|
||||
|
||||
expect(client).toBeDefined();
|
||||
expect(client.id).toBe('default-lm');
|
||||
expect(client.vendor).toBe('vscode');
|
||||
});
|
||||
});
|
||||
expect(client).toBeDefined()
|
||||
expect(client.id).toBe("default-lm")
|
||||
expect(client.vendor).toBe("vscode")
|
||||
})
|
||||
})
|
||||
|
||||
describe('createMessage', () => {
|
||||
describe("createMessage", () => {
|
||||
beforeEach(() => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
mockLanguageModelChat.countTokens.mockResolvedValue(10);
|
||||
});
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
mockLanguageModelChat.countTokens.mockResolvedValue(10)
|
||||
})
|
||||
|
||||
it('should stream text responses', async () => {
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user' as const,
|
||||
content: 'Hello'
|
||||
}];
|
||||
it("should stream text responses", async () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: "Hello",
|
||||
},
|
||||
]
|
||||
|
||||
const responseText = 'Hello! How can I help you?';
|
||||
const responseText = "Hello! How can I help you?"
|
||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelTextPart(responseText);
|
||||
return;
|
||||
yield new vscode.LanguageModelTextPart(responseText)
|
||||
return
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield responseText;
|
||||
return;
|
||||
})()
|
||||
});
|
||||
yield responseText
|
||||
return
|
||||
})(),
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks).toHaveLength(2); // Text chunk + usage chunk
|
||||
expect(chunks).toHaveLength(2) // Text chunk + usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: responseText
|
||||
});
|
||||
type: "text",
|
||||
text: responseText,
|
||||
})
|
||||
expect(chunks[1]).toMatchObject({
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: expect.any(Number),
|
||||
outputTokens: expect.any(Number)
|
||||
});
|
||||
});
|
||||
outputTokens: expect.any(Number),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tool calls', async () => {
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user' as const,
|
||||
content: 'Calculate 2+2'
|
||||
}];
|
||||
it("should handle tool calls", async () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: "Calculate 2+2",
|
||||
},
|
||||
]
|
||||
|
||||
const toolCallData = {
|
||||
name: 'calculator',
|
||||
arguments: { operation: 'add', numbers: [2, 2] },
|
||||
callId: 'call-1'
|
||||
};
|
||||
name: "calculator",
|
||||
arguments: { operation: "add", numbers: [2, 2] },
|
||||
callId: "call-1",
|
||||
}
|
||||
|
||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelToolCallPart(
|
||||
toolCallData.callId,
|
||||
toolCallData.name,
|
||||
toolCallData.arguments
|
||||
);
|
||||
return;
|
||||
toolCallData.arguments,
|
||||
)
|
||||
return
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield JSON.stringify({ type: 'tool_call', ...toolCallData });
|
||||
return;
|
||||
})()
|
||||
});
|
||||
yield JSON.stringify({ type: "tool_call", ...toolCallData })
|
||||
return
|
||||
})(),
|
||||
})
|
||||
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks = [];
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
const chunks = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks).toHaveLength(2); // Tool call chunk + usage chunk
|
||||
expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'text',
|
||||
text: JSON.stringify({ type: 'tool_call', ...toolCallData })
|
||||
});
|
||||
});
|
||||
type: "text",
|
||||
text: JSON.stringify({ type: "tool_call", ...toolCallData }),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle errors', async () => {
|
||||
const systemPrompt = 'You are a helpful assistant';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user' as const,
|
||||
content: 'Hello'
|
||||
}];
|
||||
it("should handle errors", async () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: "Hello",
|
||||
},
|
||||
]
|
||||
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('API Error'));
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
await expect(async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const stream = handler.createMessage(systemPrompt, messages)
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
});
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info when client exists', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
describe("getModel", () => {
|
||||
it("should return model info when client exists", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
// Initialize client
|
||||
await handler['getClient']();
|
||||
await handler["getClient"]()
|
||||
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe('test-model');
|
||||
expect(model.info).toBeDefined();
|
||||
expect(model.info.contextWindow).toBe(4096);
|
||||
});
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe("test-model")
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.contextWindow).toBe(4096)
|
||||
})
|
||||
|
||||
it('should return fallback model info when no client exists', () => {
|
||||
const model = handler.getModel();
|
||||
expect(model.id).toBe('test-vendor/test-family');
|
||||
expect(model.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
it("should return fallback model info when no client exists", () => {
|
||||
const model = handler.getModel()
|
||||
expect(model.id).toBe("test-vendor/test-family")
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete single prompt', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
describe("completePrompt", () => {
|
||||
it("should complete single prompt", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
const responseText = 'Completed text';
|
||||
const responseText = "Completed text"
|
||||
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelTextPart(responseText);
|
||||
return;
|
||||
yield new vscode.LanguageModelTextPart(responseText)
|
||||
return
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield responseText;
|
||||
return;
|
||||
})()
|
||||
});
|
||||
yield responseText
|
||||
return
|
||||
})(),
|
||||
})
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe(responseText);
|
||||
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled();
|
||||
});
|
||||
const result = await handler.completePrompt("Test prompt")
|
||||
expect(result).toBe(responseText)
|
||||
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle errors during completion', async () => {
|
||||
const mockModel = { ...mockLanguageModelChat };
|
||||
(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]);
|
||||
it("should handle errors during completion", async () => {
|
||||
const mockModel = { ...mockLanguageModelChat }
|
||||
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
|
||||
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error('Completion failed'));
|
||||
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))
|
||||
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects
|
||||
.toThrow('VSCode LM completion error: Completion failed');
|
||||
});
|
||||
});
|
||||
});
|
||||
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
||||
"VSCode LM completion error: Completion failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -181,14 +181,14 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
|
||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
|
||||
const content = response.content[0]
|
||||
if (content.type === 'text') {
|
||||
if (content.type === "text") {
|
||||
return content.text
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Anthropic completion error: ${error.message}`)
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseStreamCommand,
|
||||
ConverseCommand,
|
||||
BedrockRuntimeClientConfig,
|
||||
} from "@aws-sdk/client-bedrock-runtime"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
||||
@@ -8,34 +13,34 @@ import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../
|
||||
// Define types for stream events based on AWS SDK
|
||||
export interface StreamEvent {
|
||||
messageStart?: {
|
||||
role?: string;
|
||||
};
|
||||
role?: string
|
||||
}
|
||||
messageStop?: {
|
||||
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence";
|
||||
additionalModelResponseFields?: Record<string, unknown>;
|
||||
};
|
||||
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
|
||||
additionalModelResponseFields?: Record<string, unknown>
|
||||
}
|
||||
contentBlockStart?: {
|
||||
start?: {
|
||||
text?: string;
|
||||
};
|
||||
contentBlockIndex?: number;
|
||||
};
|
||||
text?: string
|
||||
}
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
contentBlockDelta?: {
|
||||
delta?: {
|
||||
text?: string;
|
||||
};
|
||||
contentBlockIndex?: number;
|
||||
};
|
||||
text?: string
|
||||
}
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
metadata?: {
|
||||
usage?: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalTokens?: number; // Made optional since we don't use it
|
||||
};
|
||||
inputTokens: number
|
||||
outputTokens: number
|
||||
totalTokens?: number // Made optional since we don't use it
|
||||
}
|
||||
metrics?: {
|
||||
latencyMs: number;
|
||||
};
|
||||
};
|
||||
latencyMs: number
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
@@ -47,7 +52,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
|
||||
// Only include credentials if they actually exist
|
||||
const clientConfig: BedrockRuntimeClientConfig = {
|
||||
region: this.options.awsRegion || "us-east-1"
|
||||
region: this.options.awsRegion || "us-east-1",
|
||||
}
|
||||
|
||||
if (this.options.awsAccessKey && this.options.awsSecretKey) {
|
||||
@@ -55,7 +60,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
clientConfig.credentials = {
|
||||
accessKeyId: this.options.awsAccessKey,
|
||||
secretAccessKey: this.options.awsSecretKey,
|
||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {})
|
||||
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,12 +101,14 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1,
|
||||
...(this.options.awsUsePromptCache ? {
|
||||
...(this.options.awsUsePromptCache
|
||||
? {
|
||||
promptCache: {
|
||||
promptCacheId: this.options.awspromptCacheId || ""
|
||||
}
|
||||
} : {})
|
||||
promptCacheId: this.options.awspromptCacheId || "",
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
},
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -109,18 +116,16 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
const response = await this.client.send(command)
|
||||
|
||||
if (!response.stream) {
|
||||
throw new Error('No stream available in the response')
|
||||
throw new Error("No stream available in the response")
|
||||
}
|
||||
|
||||
for await (const chunk of response.stream) {
|
||||
// Parse the chunk as JSON if it's a string (for tests)
|
||||
let streamEvent: StreamEvent
|
||||
try {
|
||||
streamEvent = typeof chunk === 'string' ?
|
||||
JSON.parse(chunk) :
|
||||
chunk as unknown as StreamEvent
|
||||
streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
|
||||
} catch (e) {
|
||||
console.error('Failed to parse stream event:', e)
|
||||
console.error("Failed to parse stream event:", e)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -129,7 +134,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0
|
||||
outputTokens: streamEvent.metadata.usage.outputTokens || 0,
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -143,7 +148,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
if (streamEvent.contentBlockStart?.start?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockStart.start.text
|
||||
text: streamEvent.contentBlockStart.start.text,
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -152,7 +157,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
if (streamEvent.contentBlockDelta?.delta?.text) {
|
||||
yield {
|
||||
type: "text",
|
||||
text: streamEvent.contentBlockDelta.delta.text
|
||||
text: streamEvent.contentBlockDelta.delta.text,
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -162,32 +167,31 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error: unknown) {
|
||||
console.error('Bedrock Runtime API Error:', error)
|
||||
console.error("Bedrock Runtime API Error:", error)
|
||||
// Only access stack if error is an Error object
|
||||
if (error instanceof Error) {
|
||||
console.error('Error stack:', error.stack)
|
||||
console.error("Error stack:", error.stack)
|
||||
yield {
|
||||
type: "text",
|
||||
text: `Error: ${error.message}`
|
||||
text: `Error: ${error.message}`,
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
outputTokens: 0,
|
||||
}
|
||||
throw error
|
||||
} else {
|
||||
const unknownError = new Error("An unknown error occurred")
|
||||
yield {
|
||||
type: "text",
|
||||
text: unknownError.message
|
||||
text: unknownError.message,
|
||||
}
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
outputTokens: 0,
|
||||
}
|
||||
throw unknownError
|
||||
}
|
||||
@@ -198,14 +202,14 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
const modelId = this.options.apiModelId
|
||||
if (modelId) {
|
||||
// For tests, allow any model ID
|
||||
if (process.env.NODE_ENV === 'test') {
|
||||
if (process.env.NODE_ENV === "test") {
|
||||
return {
|
||||
id: modelId,
|
||||
info: {
|
||||
maxTokens: 5000,
|
||||
contextWindow: 128_000,
|
||||
supportsPromptCache: false
|
||||
}
|
||||
supportsPromptCache: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
// For production, validate against known models
|
||||
@@ -216,7 +220,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
}
|
||||
return {
|
||||
id: bedrockDefaultModelId,
|
||||
info: bedrockModels[bedrockDefaultModelId]
|
||||
info: bedrockModels[bedrockDefaultModelId],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,15 +249,17 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
|
||||
const payload = {
|
||||
modelId,
|
||||
messages: convertToBedrockConverseMessages([{
|
||||
messages: convertToBedrockConverseMessages([
|
||||
{
|
||||
role: "user",
|
||||
content: prompt
|
||||
}]),
|
||||
content: prompt,
|
||||
},
|
||||
]),
|
||||
inferenceConfig: {
|
||||
maxTokens: modelConfig.info.maxTokens || 5000,
|
||||
temperature: 0.3,
|
||||
topP: 0.1
|
||||
}
|
||||
topP: 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(payload)
|
||||
@@ -267,10 +273,10 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
|
||||
return output.content
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse Bedrock response:', parseError)
|
||||
console.error("Failed to parse Bedrock response:", parseError)
|
||||
}
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Bedrock completion error: ${error.message}`)
|
||||
|
||||
@@ -12,7 +12,7 @@ export class DeepSeekHandler extends OpenAiHandler {
|
||||
openAiApiKey: options.deepSeekApiKey,
|
||||
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
|
||||
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
|
||||
includeMaxTokens: true
|
||||
includeMaxTokens: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ export class DeepSeekHandler extends OpenAiHandler {
|
||||
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
|
||||
return {
|
||||
id: modelId,
|
||||
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
|
||||
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,17 +72,17 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
maxTokens = 8_192
|
||||
}
|
||||
|
||||
const { data: completion, response } = await this.client.chat.completions.create({
|
||||
const { data: completion, response } = await this.client.chat.completions
|
||||
.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: maxTokens,
|
||||
temperature: 0,
|
||||
messages: openAiMessages,
|
||||
stream: true,
|
||||
}).withResponse();
|
||||
})
|
||||
.withResponse()
|
||||
|
||||
const completionRequestId = response.headers.get(
|
||||
'x-completion-request-id',
|
||||
);
|
||||
const completionRequestId = response.headers.get("x-completion-request-id")
|
||||
|
||||
for await (const chunk of completion) {
|
||||
const delta = chunk.choices[0]?.delta
|
||||
@@ -96,13 +96,16 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.get(`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, {
|
||||
const response = await axios.get(
|
||||
`https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`,
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.options.glamaApiKey}`,
|
||||
},
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
const completionRequest = response.data;
|
||||
const completionRequest = response.data
|
||||
|
||||
if (completionRequest.tokenUsage) {
|
||||
yield {
|
||||
|
||||
@@ -60,7 +60,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
|
||||
@@ -53,7 +53,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
return response.choices[0]?.message.content || ""
|
||||
} catch (error) {
|
||||
|
||||
@@ -32,7 +32,10 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
|
||||
// o1 doesnt support streaming or non-1 temp but does support a developer prompt
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: modelId,
|
||||
messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
|
||||
messages: [
|
||||
{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
|
||||
...convertToOpenAiMessages(messages),
|
||||
],
|
||||
})
|
||||
yield {
|
||||
type: "text",
|
||||
@@ -98,14 +101,14 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
|
||||
// o1 doesn't support non-1 temp
|
||||
requestOptions = {
|
||||
model: modelId,
|
||||
messages: [{ role: "user", content: prompt }]
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
}
|
||||
break
|
||||
default:
|
||||
requestOptions = {
|
||||
model: modelId,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0
|
||||
temperature: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options
|
||||
// Azure API shape slightly differs from the core API shape: https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
|
||||
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host;
|
||||
const urlHost = new URL(this.options.openAiBaseUrl ?? "").host
|
||||
if (urlHost === "azure.com" || urlHost.endsWith(".azure.com")) {
|
||||
this.client = new AzureOpenAI({
|
||||
baseURL: this.options.openAiBaseUrl,
|
||||
@@ -39,7 +39,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
if (this.options.openAiStreamingEnabled ?? true) {
|
||||
const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
|
||||
role: "system",
|
||||
content: systemPrompt
|
||||
content: systemPrompt,
|
||||
}
|
||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: modelId,
|
||||
@@ -74,7 +74,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||
// o1 for instance doesnt support streaming, non-1 temp, or system prompt
|
||||
const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
|
||||
role: "user",
|
||||
content: systemPrompt
|
||||
content: systemPrompt,
|
||||
}
|
||||
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model: modelId,
|
||||
|
||||
@@ -9,12 +9,12 @@ import delay from "delay"
|
||||
|
||||
// Add custom interface for OpenRouter params
|
||||
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
|
||||
transforms?: string[];
|
||||
transforms?: string[]
|
||||
}
|
||||
|
||||
// Add custom interface for OpenRouter usage chunk
|
||||
interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
|
||||
fullResponseText: string;
|
||||
fullResponseText: string
|
||||
}
|
||||
|
||||
import { SingleCompletionHandler } from ".."
|
||||
@@ -35,7 +35,10 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
})
|
||||
}
|
||||
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator<ApiStreamChunk> {
|
||||
async *createMessage(
|
||||
systemPrompt: string,
|
||||
messages: Anthropic.Messages.MessageParam[],
|
||||
): AsyncGenerator<ApiStreamChunk> {
|
||||
// Convert Anthropic messages to OpenAI format
|
||||
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
|
||||
{ role: "system", content: systemPrompt },
|
||||
@@ -108,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
break
|
||||
}
|
||||
// https://openrouter.ai/docs/transforms
|
||||
let fullResponseText = "";
|
||||
let fullResponseText = ""
|
||||
const stream = await this.client.chat.completions.create({
|
||||
model: this.getModel().id,
|
||||
max_tokens: maxTokens,
|
||||
@@ -116,8 +119,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
messages: openAiMessages,
|
||||
stream: true,
|
||||
// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
|
||||
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] })
|
||||
} as OpenRouterChatCompletionParams);
|
||||
...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
|
||||
} as OpenRouterChatCompletionParams)
|
||||
|
||||
let genId: string | undefined
|
||||
|
||||
@@ -135,11 +138,11 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
|
||||
const delta = chunk.choices[0]?.delta
|
||||
if (delta?.content) {
|
||||
fullResponseText += delta.content;
|
||||
fullResponseText += delta.content
|
||||
yield {
|
||||
type: "text",
|
||||
text: delta.content,
|
||||
} as ApiStreamChunk;
|
||||
} as ApiStreamChunk
|
||||
}
|
||||
// if (chunk.usage) {
|
||||
// yield {
|
||||
@@ -170,13 +173,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
inputTokens: generation?.native_tokens_prompt || 0,
|
||||
outputTokens: generation?.native_tokens_completion || 0,
|
||||
totalCost: generation?.total_cost || 0,
|
||||
fullResponseText
|
||||
} as OpenRouterApiStreamUsageChunk;
|
||||
fullResponseText,
|
||||
} as OpenRouterApiStreamUsageChunk
|
||||
} catch (error) {
|
||||
// ignore if fails
|
||||
console.error("Error fetching OpenRouter generation details:", error)
|
||||
}
|
||||
|
||||
}
|
||||
getModel(): { id: string; info: ModelInfo } {
|
||||
const modelId = this.options.openRouterModelId
|
||||
@@ -193,7 +195,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
|
||||
if ("error" in response) {
|
||||
|
||||
@@ -91,14 +91,14 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
|
||||
max_tokens: this.getModel().info.maxTokens || 8192,
|
||||
temperature: 0,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
stream: false
|
||||
stream: false,
|
||||
})
|
||||
|
||||
const content = response.content[0]
|
||||
if (content.type === 'text') {
|
||||
if (content.type === "text") {
|
||||
return content.text
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`Vertex completion error: ${error.message}`)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk";
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiHandler, SingleCompletionHandler } from "../";
|
||||
import { calculateApiCost } from "../../utils/cost";
|
||||
import { ApiStream } from "../transform/stream";
|
||||
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format";
|
||||
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils";
|
||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api";
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import * as vscode from "vscode"
|
||||
import { ApiHandler, SingleCompletionHandler } from "../"
|
||||
import { calculateApiCost } from "../../utils/cost"
|
||||
import { ApiStream } from "../transform/stream"
|
||||
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
|
||||
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
|
||||
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
|
||||
|
||||
/**
|
||||
* Handles interaction with VS Code's Language Model API for chat-based operations.
|
||||
@@ -35,39 +35,36 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
|
||||
* ```
|
||||
*/
|
||||
export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
|
||||
private options: ApiHandlerOptions;
|
||||
private client: vscode.LanguageModelChat | null;
|
||||
private disposable: vscode.Disposable | null;
|
||||
private currentRequestCancellation: vscode.CancellationTokenSource | null;
|
||||
private options: ApiHandlerOptions
|
||||
private client: vscode.LanguageModelChat | null
|
||||
private disposable: vscode.Disposable | null
|
||||
private currentRequestCancellation: vscode.CancellationTokenSource | null
|
||||
|
||||
constructor(options: ApiHandlerOptions) {
|
||||
this.options = options;
|
||||
this.client = null;
|
||||
this.disposable = null;
|
||||
this.currentRequestCancellation = null;
|
||||
this.options = options
|
||||
this.client = null
|
||||
this.disposable = null
|
||||
this.currentRequestCancellation = null
|
||||
|
||||
try {
|
||||
// Listen for model changes and reset client
|
||||
this.disposable = vscode.workspace.onDidChangeConfiguration(event => {
|
||||
if (event.affectsConfiguration('lm')) {
|
||||
this.disposable = vscode.workspace.onDidChangeConfiguration((event) => {
|
||||
if (event.affectsConfiguration("lm")) {
|
||||
try {
|
||||
this.client = null;
|
||||
this.ensureCleanState();
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error during configuration change cleanup:', error);
|
||||
this.client = null
|
||||
this.ensureCleanState()
|
||||
} catch (error) {
|
||||
console.error("Error during configuration change cleanup:", error)
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
catch (error) {
|
||||
})
|
||||
} catch (error) {
|
||||
// Ensure cleanup if constructor fails
|
||||
this.dispose();
|
||||
this.dispose()
|
||||
|
||||
throw new Error(
|
||||
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
);
|
||||
`Cline <Language Model API>: Failed to initialize handler: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,39 +81,39 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
*/
|
||||
async createClient(selector: vscode.LanguageModelChatSelector): Promise<vscode.LanguageModelChat> {
|
||||
try {
|
||||
const models = await vscode.lm.selectChatModels(selector);
|
||||
const models = await vscode.lm.selectChatModels(selector)
|
||||
|
||||
// Use first available model or create a minimal model object
|
||||
if (models && Array.isArray(models) && models.length > 0) {
|
||||
return models[0];
|
||||
return models[0]
|
||||
}
|
||||
|
||||
// Create a minimal model if no models are available
|
||||
return {
|
||||
id: 'default-lm',
|
||||
name: 'Default Language Model',
|
||||
vendor: 'vscode',
|
||||
family: 'lm',
|
||||
version: '1.0',
|
||||
id: "default-lm",
|
||||
name: "Default Language Model",
|
||||
vendor: "vscode",
|
||||
family: "lm",
|
||||
version: "1.0",
|
||||
maxInputTokens: 8192,
|
||||
sendRequest: async (messages, options, token) => {
|
||||
// Provide a minimal implementation
|
||||
return {
|
||||
stream: (async function* () {
|
||||
yield new vscode.LanguageModelTextPart(
|
||||
"Language model functionality is limited. Please check VS Code configuration."
|
||||
);
|
||||
"Language model functionality is limited. Please check VS Code configuration.",
|
||||
)
|
||||
})(),
|
||||
text: (async function* () {
|
||||
yield "Language model functionality is limited. Please check VS Code configuration.";
|
||||
})()
|
||||
};
|
||||
yield "Language model functionality is limited. Please check VS Code configuration."
|
||||
})(),
|
||||
}
|
||||
},
|
||||
countTokens: async () => 0
|
||||
};
|
||||
countTokens: async () => 0,
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`);
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error"
|
||||
throw new Error(`Cline <Language Model API>: Failed to select model: ${errorMessage}`)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,230 +134,222 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
* Tool calls handling is currently a work in progress.
|
||||
*/
|
||||
dispose(): void {
|
||||
|
||||
if (this.disposable) {
|
||||
|
||||
this.disposable.dispose();
|
||||
this.disposable.dispose()
|
||||
}
|
||||
|
||||
if (this.currentRequestCancellation) {
|
||||
|
||||
this.currentRequestCancellation.cancel();
|
||||
this.currentRequestCancellation.dispose();
|
||||
this.currentRequestCancellation.cancel()
|
||||
this.currentRequestCancellation.dispose()
|
||||
}
|
||||
}
|
||||
|
||||
private async countTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
|
||||
// Check for required dependencies
|
||||
if (!this.client) {
|
||||
console.warn('Cline <Language Model API>: No client available for token counting');
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: No client available for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
if (!this.currentRequestCancellation) {
|
||||
console.warn('Cline <Language Model API>: No cancellation token available for token counting');
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: No cancellation token available for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
// Validate input
|
||||
if (!text) {
|
||||
console.debug('Cline <Language Model API>: Empty text provided for token counting');
|
||||
return 0;
|
||||
console.debug("Cline <Language Model API>: Empty text provided for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
try {
|
||||
// Handle different input types
|
||||
let tokenCount: number;
|
||||
let tokenCount: number
|
||||
|
||||
if (typeof text === 'string') {
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
|
||||
if (typeof text === "string") {
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
|
||||
} else if (text instanceof vscode.LanguageModelChatMessage) {
|
||||
// For chat messages, ensure we have content
|
||||
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
|
||||
console.debug('Cline <Language Model API>: Empty chat message content');
|
||||
return 0;
|
||||
console.debug("Cline <Language Model API>: Empty chat message content")
|
||||
return 0
|
||||
}
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token);
|
||||
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
|
||||
} else {
|
||||
console.warn('Cline <Language Model API>: Invalid input type for token counting');
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: Invalid input type for token counting")
|
||||
return 0
|
||||
}
|
||||
|
||||
// Validate the result
|
||||
if (typeof tokenCount !== 'number') {
|
||||
console.warn('Cline <Language Model API>: Non-numeric token count received:', tokenCount);
|
||||
return 0;
|
||||
if (typeof tokenCount !== "number") {
|
||||
console.warn("Cline <Language Model API>: Non-numeric token count received:", tokenCount)
|
||||
return 0
|
||||
}
|
||||
|
||||
if (tokenCount < 0) {
|
||||
console.warn('Cline <Language Model API>: Negative token count received:', tokenCount);
|
||||
return 0;
|
||||
console.warn("Cline <Language Model API>: Negative token count received:", tokenCount)
|
||||
return 0
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
catch (error) {
|
||||
return tokenCount
|
||||
} catch (error) {
|
||||
// Handle specific error types
|
||||
if (error instanceof vscode.CancellationError) {
|
||||
console.debug('Cline <Language Model API>: Token counting cancelled by user');
|
||||
return 0;
|
||||
console.debug("Cline <Language Model API>: Token counting cancelled by user")
|
||||
return 0
|
||||
}
|
||||
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
console.warn('Cline <Language Model API>: Token counting failed:', errorMessage);
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error"
|
||||
console.warn("Cline <Language Model API>: Token counting failed:", errorMessage)
|
||||
|
||||
// Log additional error details if available
|
||||
if (error instanceof Error && error.stack) {
|
||||
console.debug('Token counting error stack:', error.stack);
|
||||
console.debug("Token counting error stack:", error.stack)
|
||||
}
|
||||
|
||||
return 0; // Fallback to prevent stream interruption
|
||||
return 0 // Fallback to prevent stream interruption
|
||||
}
|
||||
}
|
||||
|
||||
private async calculateTotalInputTokens(systemPrompt: string, vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
|
||||
private async calculateTotalInputTokens(
|
||||
systemPrompt: string,
|
||||
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
|
||||
): Promise<number> {
|
||||
const systemTokens: number = await this.countTokens(systemPrompt)
|
||||
|
||||
const systemTokens: number = await this.countTokens(systemPrompt);
|
||||
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.countTokens(msg)))
|
||||
|
||||
const messageTokens: number[] = await Promise.all(
|
||||
vsCodeLmMessages.map(msg => this.countTokens(msg))
|
||||
);
|
||||
|
||||
return systemTokens + messageTokens.reduce(
|
||||
(sum: number, tokens: number): number => sum + tokens, 0
|
||||
);
|
||||
return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
|
||||
}
|
||||
|
||||
private ensureCleanState(): void {
|
||||
|
||||
if (this.currentRequestCancellation) {
|
||||
|
||||
this.currentRequestCancellation.cancel();
|
||||
this.currentRequestCancellation.dispose();
|
||||
this.currentRequestCancellation = null;
|
||||
this.currentRequestCancellation.cancel()
|
||||
this.currentRequestCancellation.dispose()
|
||||
this.currentRequestCancellation = null
|
||||
}
|
||||
}
|
||||
|
||||
private async getClient(): Promise<vscode.LanguageModelChat> {
|
||||
if (!this.client) {
|
||||
console.debug('Cline <Language Model API>: Getting client with options:', {
|
||||
console.debug("Cline <Language Model API>: Getting client with options:", {
|
||||
vsCodeLmModelSelector: this.options.vsCodeLmModelSelector,
|
||||
hasOptions: !!this.options,
|
||||
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : []
|
||||
});
|
||||
selectorKeys: this.options.vsCodeLmModelSelector ? Object.keys(this.options.vsCodeLmModelSelector) : [],
|
||||
})
|
||||
|
||||
try {
|
||||
// Use default empty selector if none provided to get all available models
|
||||
const selector = this.options?.vsCodeLmModelSelector || {};
|
||||
console.debug('Cline <Language Model API>: Creating client with selector:', selector);
|
||||
this.client = await this.createClient(selector);
|
||||
const selector = this.options?.vsCodeLmModelSelector || {}
|
||||
console.debug("Cline <Language Model API>: Creating client with selector:", selector)
|
||||
this.client = await this.createClient(selector)
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Unknown error';
|
||||
console.error('Cline <Language Model API>: Client creation failed:', message);
|
||||
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`);
|
||||
const message = error instanceof Error ? error.message : "Unknown error"
|
||||
console.error("Cline <Language Model API>: Client creation failed:", message)
|
||||
throw new Error(`Cline <Language Model API>: Failed to create client: ${message}`)
|
||||
}
|
||||
}
|
||||
|
||||
return this.client;
|
||||
return this.client
|
||||
}
|
||||
|
||||
private cleanTerminalOutput(text: string): string {
|
||||
if (!text) {
|
||||
return '';
|
||||
return ""
|
||||
}
|
||||
|
||||
return text
|
||||
return (
|
||||
text
|
||||
// Нормализуем переносы строк
|
||||
.replace(/\r\n/g, '\n')
|
||||
.replace(/\r/g, '\n')
|
||||
.replace(/\r\n/g, "\n")
|
||||
.replace(/\r/g, "\n")
|
||||
|
||||
// Удаляем ANSI escape sequences
|
||||
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, '') // Полный набор ANSI sequences
|
||||
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, '') // CSI sequences
|
||||
.replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "") // Полный набор ANSI sequences
|
||||
.replace(/\x9B[0-?]*[ -/]*[@-~]/g, "") // CSI sequences
|
||||
|
||||
// Удаляем последовательности установки заголовка терминала и прочие OSC sequences
|
||||
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, '')
|
||||
.replace(/\x1B\][0-9;]*(?:\x07|\x1B\\)/g, "")
|
||||
|
||||
// Удаляем управляющие символы
|
||||
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, '')
|
||||
.replace(/[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]/g, "")
|
||||
|
||||
// Удаляем escape-последовательности VS Code
|
||||
.replace(/\x1B[PD].*?\x1B\\/g, '') // DCS sequences
|
||||
.replace(/\x1B_.*?\x1B\\/g, '') // APC sequences
|
||||
.replace(/\x1B\^.*?\x1B\\/g, '') // PM sequences
|
||||
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, '') // Cursor movement and clear screen
|
||||
.replace(/\x1B[PD].*?\x1B\\/g, "") // DCS sequences
|
||||
.replace(/\x1B_.*?\x1B\\/g, "") // APC sequences
|
||||
.replace(/\x1B\^.*?\x1B\\/g, "") // PM sequences
|
||||
.replace(/\x1B\[[\d;]*[HfABCDEFGJKST]/g, "") // Cursor movement and clear screen
|
||||
|
||||
// Удаляем пути Windows и служебную информацию
|
||||
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/mg, '')
|
||||
.replace(/^;?Cwd=.*$/mg, '')
|
||||
.replace(/^(?:PS )?[A-Z]:\\[^\n]*$/gm, "")
|
||||
.replace(/^;?Cwd=.*$/gm, "")
|
||||
|
||||
// Очищаем экранированные последовательности
|
||||
.replace(/\\x[0-9a-fA-F]{2}/g, '')
|
||||
.replace(/\\u[0-9a-fA-F]{4}/g, '')
|
||||
.replace(/\\x[0-9a-fA-F]{2}/g, "")
|
||||
.replace(/\\u[0-9a-fA-F]{4}/g, "")
|
||||
|
||||
// Финальная очистка
|
||||
.replace(/\n{3,}/g, '\n\n') // Убираем множественные пустые строки
|
||||
.trim();
|
||||
.replace(/\n{3,}/g, "\n\n") // Убираем множественные пустые строки
|
||||
.trim()
|
||||
)
|
||||
}
|
||||
|
||||
private cleanMessageContent(content: any): any {
|
||||
if (!content) {
|
||||
return content;
|
||||
return content
|
||||
}
|
||||
|
||||
if (typeof content === 'string') {
|
||||
return this.cleanTerminalOutput(content);
|
||||
if (typeof content === "string") {
|
||||
return this.cleanTerminalOutput(content)
|
||||
}
|
||||
|
||||
if (Array.isArray(content)) {
|
||||
return content.map(item => this.cleanMessageContent(item));
|
||||
return content.map((item) => this.cleanMessageContent(item))
|
||||
}
|
||||
|
||||
if (typeof content === 'object') {
|
||||
const cleaned: any = {};
|
||||
if (typeof content === "object") {
|
||||
const cleaned: any = {}
|
||||
for (const [key, value] of Object.entries(content)) {
|
||||
cleaned[key] = this.cleanMessageContent(value);
|
||||
cleaned[key] = this.cleanMessageContent(value)
|
||||
}
|
||||
return cleaned;
|
||||
return cleaned
|
||||
}
|
||||
|
||||
return content;
|
||||
return content
|
||||
}
|
||||
|
||||
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
||||
|
||||
// Ensure clean state before starting a new request
|
||||
this.ensureCleanState();
|
||||
const client: vscode.LanguageModelChat = await this.getClient();
|
||||
this.ensureCleanState()
|
||||
const client: vscode.LanguageModelChat = await this.getClient()
|
||||
|
||||
// Clean system prompt and messages
|
||||
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt);
|
||||
const cleanedMessages = messages.map(msg => ({
|
||||
const cleanedSystemPrompt = this.cleanTerminalOutput(systemPrompt)
|
||||
const cleanedMessages = messages.map((msg) => ({
|
||||
...msg,
|
||||
content: this.cleanMessageContent(msg.content)
|
||||
}));
|
||||
content: this.cleanMessageContent(msg.content),
|
||||
}))
|
||||
|
||||
// Convert Anthropic messages to VS Code LM messages
|
||||
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [
|
||||
vscode.LanguageModelChatMessage.Assistant(cleanedSystemPrompt),
|
||||
...convertToVsCodeLmMessages(cleanedMessages),
|
||||
];
|
||||
]
|
||||
|
||||
// Initialize cancellation token for the request
|
||||
this.currentRequestCancellation = new vscode.CancellationTokenSource();
|
||||
this.currentRequestCancellation = new vscode.CancellationTokenSource()
|
||||
|
||||
// Calculate input tokens before starting the stream
|
||||
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages);
|
||||
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
|
||||
|
||||
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
|
||||
let accumulatedText: string = '';
|
||||
let accumulatedText: string = ""
|
||||
|
||||
try {
|
||||
|
||||
// Create the response stream with minimal required options
|
||||
const requestOptions: vscode.LanguageModelChatRequestOptions = {
|
||||
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`
|
||||
};
|
||||
justification: `Cline would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`,
|
||||
}
|
||||
|
||||
// Note: Tool support is currently provided by the VSCode Language Model API directly
|
||||
// Extensions can register tools using vscode.lm.registerTool()
|
||||
@@ -368,40 +357,40 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
const response: vscode.LanguageModelChatResponse = await client.sendRequest(
|
||||
vsCodeLmMessages,
|
||||
requestOptions,
|
||||
this.currentRequestCancellation.token
|
||||
);
|
||||
this.currentRequestCancellation.token,
|
||||
)
|
||||
|
||||
// Consume the stream and handle both text and tool call chunks
|
||||
for await (const chunk of response.stream) {
|
||||
if (chunk instanceof vscode.LanguageModelTextPart) {
|
||||
// Validate text part value
|
||||
if (typeof chunk.value !== 'string') {
|
||||
console.warn('Cline <Language Model API>: Invalid text part value received:', chunk.value);
|
||||
continue;
|
||||
if (typeof chunk.value !== "string") {
|
||||
console.warn("Cline <Language Model API>: Invalid text part value received:", chunk.value)
|
||||
continue
|
||||
}
|
||||
|
||||
accumulatedText += chunk.value;
|
||||
accumulatedText += chunk.value
|
||||
yield {
|
||||
type: "text",
|
||||
text: chunk.value,
|
||||
};
|
||||
}
|
||||
} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
|
||||
try {
|
||||
// Validate tool call parameters
|
||||
if (!chunk.name || typeof chunk.name !== 'string') {
|
||||
console.warn('Cline <Language Model API>: Invalid tool name received:', chunk.name);
|
||||
continue;
|
||||
if (!chunk.name || typeof chunk.name !== "string") {
|
||||
console.warn("Cline <Language Model API>: Invalid tool name received:", chunk.name)
|
||||
continue
|
||||
}
|
||||
|
||||
if (!chunk.callId || typeof chunk.callId !== 'string') {
|
||||
console.warn('Cline <Language Model API>: Invalid tool callId received:', chunk.callId);
|
||||
continue;
|
||||
if (!chunk.callId || typeof chunk.callId !== "string") {
|
||||
console.warn("Cline <Language Model API>: Invalid tool callId received:", chunk.callId)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure input is a valid object
|
||||
if (!chunk.input || typeof chunk.input !== 'object') {
|
||||
console.warn('Cline <Language Model API>: Invalid tool input received:', chunk.input);
|
||||
continue;
|
||||
if (!chunk.input || typeof chunk.input !== "object") {
|
||||
console.warn("Cline <Language Model API>: Invalid tool input received:", chunk.input)
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert tool calls to text format with proper error handling
|
||||
@@ -409,82 +398,75 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
type: "tool_call",
|
||||
name: chunk.name,
|
||||
arguments: chunk.input,
|
||||
callId: chunk.callId
|
||||
};
|
||||
callId: chunk.callId,
|
||||
}
|
||||
|
||||
const toolCallText = JSON.stringify(toolCall);
|
||||
accumulatedText += toolCallText;
|
||||
const toolCallText = JSON.stringify(toolCall)
|
||||
accumulatedText += toolCallText
|
||||
|
||||
// Log tool call for debugging
|
||||
console.debug('Cline <Language Model API>: Processing tool call:', {
|
||||
console.debug("Cline <Language Model API>: Processing tool call:", {
|
||||
name: chunk.name,
|
||||
callId: chunk.callId,
|
||||
inputSize: JSON.stringify(chunk.input).length
|
||||
});
|
||||
inputSize: JSON.stringify(chunk.input).length,
|
||||
})
|
||||
|
||||
yield {
|
||||
type: "text",
|
||||
text: toolCallText,
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Cline <Language Model API>: Failed to process tool call:', error);
|
||||
console.error("Cline <Language Model API>: Failed to process tool call:", error)
|
||||
// Continue processing other chunks even if one fails
|
||||
continue;
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
console.warn('Cline <Language Model API>: Unknown chunk type received:', chunk);
|
||||
console.warn("Cline <Language Model API>: Unknown chunk type received:", chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// Count tokens in the accumulated text after stream completion
|
||||
const totalOutputTokens: number = await this.countTokens(accumulatedText);
|
||||
const totalOutputTokens: number = await this.countTokens(accumulatedText)
|
||||
|
||||
// Report final usage after stream completion
|
||||
yield {
|
||||
type: "usage",
|
||||
inputTokens: totalInputTokens,
|
||||
outputTokens: totalOutputTokens,
|
||||
totalCost: calculateApiCost(
|
||||
this.getModel().info,
|
||||
totalInputTokens,
|
||||
totalOutputTokens
|
||||
)
|
||||
};
|
||||
totalCost: calculateApiCost(this.getModel().info, totalInputTokens, totalOutputTokens),
|
||||
}
|
||||
catch (error: unknown) {
|
||||
|
||||
this.ensureCleanState();
|
||||
} catch (error: unknown) {
|
||||
this.ensureCleanState()
|
||||
|
||||
if (error instanceof vscode.CancellationError) {
|
||||
|
||||
throw new Error("Cline <Language Model API>: Request cancelled by user");
|
||||
throw new Error("Cline <Language Model API>: Request cancelled by user")
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
console.error('Cline <Language Model API>: Stream error details:', {
|
||||
console.error("Cline <Language Model API>: Stream error details:", {
|
||||
message: error.message,
|
||||
stack: error.stack,
|
||||
name: error.name
|
||||
});
|
||||
name: error.name,
|
||||
})
|
||||
|
||||
// Return original error if it's already an Error instance
|
||||
throw error;
|
||||
} else if (typeof error === 'object' && error !== null) {
|
||||
throw error
|
||||
} else if (typeof error === "object" && error !== null) {
|
||||
// Handle error-like objects
|
||||
const errorDetails = JSON.stringify(error, null, 2);
|
||||
console.error('Cline <Language Model API>: Stream error object:', errorDetails);
|
||||
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`);
|
||||
const errorDetails = JSON.stringify(error, null, 2)
|
||||
console.error("Cline <Language Model API>: Stream error object:", errorDetails)
|
||||
throw new Error(`Cline <Language Model API>: Response stream error: ${errorDetails}`)
|
||||
} else {
|
||||
// Fallback for unknown error types
|
||||
const errorMessage = String(error);
|
||||
console.error('Cline <Language Model API>: Unknown stream error:', errorMessage);
|
||||
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`);
|
||||
const errorMessage = String(error)
|
||||
console.error("Cline <Language Model API>: Unknown stream error:", errorMessage)
|
||||
throw new Error(`Cline <Language Model API>: Response stream error: ${errorMessage}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return model information based on the current client state
|
||||
getModel(): { id: string; info: ModelInfo; } {
|
||||
getModel(): { id: string; info: ModelInfo } {
|
||||
if (this.client) {
|
||||
// Validate client properties
|
||||
const requiredProps = {
|
||||
@@ -492,68 +474,69 @@ export class VsCodeLmHandler implements ApiHandler, SingleCompletionHandler {
|
||||
vendor: this.client.vendor,
|
||||
family: this.client.family,
|
||||
version: this.client.version,
|
||||
maxInputTokens: this.client.maxInputTokens
|
||||
};
|
||||
maxInputTokens: this.client.maxInputTokens,
|
||||
}
|
||||
|
||||
// Log any missing properties for debugging
|
||||
for (const [prop, value] of Object.entries(requiredProps)) {
|
||||
if (!value && value !== 0) {
|
||||
console.warn(`Cline <Language Model API>: Client missing ${prop} property`);
|
||||
console.warn(`Cline <Language Model API>: Client missing ${prop} property`)
|
||||
}
|
||||
}
|
||||
|
||||
// Construct model ID using available information
|
||||
const modelParts = [
|
||||
this.client.vendor,
|
||||
this.client.family,
|
||||
this.client.version
|
||||
].filter(Boolean);
|
||||
const modelParts = [this.client.vendor, this.client.family, this.client.version].filter(Boolean)
|
||||
|
||||
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR);
|
||||
const modelId = this.client.id || modelParts.join(SELECTOR_SEPARATOR)
|
||||
|
||||
// Build model info with conservative defaults for missing values
|
||||
const modelInfo: ModelInfo = {
|
||||
maxTokens: -1, // Unlimited tokens by default
|
||||
contextWindow: typeof this.client.maxInputTokens === 'number'
|
||||
contextWindow:
|
||||
typeof this.client.maxInputTokens === "number"
|
||||
? Math.max(0, this.client.maxInputTokens)
|
||||
: openAiModelInfoSaneDefaults.contextWindow,
|
||||
supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
|
||||
supportsPromptCache: true,
|
||||
inputPrice: 0,
|
||||
outputPrice: 0,
|
||||
description: `VSCode Language Model: ${modelId}`
|
||||
};
|
||||
description: `VSCode Language Model: ${modelId}`,
|
||||
}
|
||||
|
||||
return { id: modelId, info: modelInfo };
|
||||
return { id: modelId, info: modelInfo }
|
||||
}
|
||||
|
||||
// Fallback when no client is available
|
||||
const fallbackId = this.options.vsCodeLmModelSelector
|
||||
? stringifyVsCodeLmModelSelector(this.options.vsCodeLmModelSelector)
|
||||
: "vscode-lm";
|
||||
: "vscode-lm"
|
||||
|
||||
console.debug('Cline <Language Model API>: No client available, using fallback model info');
|
||||
console.debug("Cline <Language Model API>: No client available, using fallback model info")
|
||||
|
||||
return {
|
||||
id: fallbackId,
|
||||
info: {
|
||||
...openAiModelInfoSaneDefaults,
|
||||
description: `VSCode Language Model (Fallback): ${fallbackId}`
|
||||
description: `VSCode Language Model (Fallback): ${fallbackId}`,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const client = await this.getClient();
|
||||
const response = await client.sendRequest([vscode.LanguageModelChatMessage.User(prompt)], {}, new vscode.CancellationTokenSource().token);
|
||||
let result = "";
|
||||
const client = await this.getClient()
|
||||
const response = await client.sendRequest(
|
||||
[vscode.LanguageModelChatMessage.User(prompt)],
|
||||
{},
|
||||
new vscode.CancellationTokenSource().token,
|
||||
)
|
||||
let result = ""
|
||||
for await (const chunk of response.stream) {
|
||||
if (chunk instanceof vscode.LanguageModelTextPart) {
|
||||
result += chunk.value;
|
||||
result += chunk.value
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return result
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`VSCode LM completion error: ${error.message}`)
|
||||
|
||||
@@ -1,251 +1,249 @@
|
||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format'
|
||||
import { Anthropic } from '@anthropic-ai/sdk'
|
||||
import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime'
|
||||
import { StreamEvent } from '../../providers/bedrock'
|
||||
import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../bedrock-converse-format"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime"
|
||||
import { StreamEvent } from "../../providers/bedrock"
|
||||
|
||||
describe('bedrock-converse-format', () => {
|
||||
describe('convertToBedrockConverseMessages', () => {
|
||||
test('converts simple text messages correctly', () => {
|
||||
describe("bedrock-converse-format", () => {
|
||||
describe("convertToBedrockConverseMessages", () => {
|
||||
test("converts simple text messages correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
{ role: 'assistant', content: 'Hi there' }
|
||||
{ role: "user", content: "Hello" },
|
||||
{ role: "assistant", content: "Hi there" },
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ text: 'Hello' }]
|
||||
role: "user",
|
||||
content: [{ text: "Hello" }],
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ text: 'Hi there' }]
|
||||
}
|
||||
role: "assistant",
|
||||
content: [{ text: "Hi there" }],
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
test('converts messages with images correctly', () => {
|
||||
test("converts messages with images correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Look at this image:'
|
||||
type: "text",
|
||||
text: "Look at this image:",
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
type: "image",
|
||||
source: {
|
||||
type: 'base64',
|
||||
data: 'SGVsbG8=', // "Hello" in base64
|
||||
media_type: 'image/jpeg' as const
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
type: "base64",
|
||||
data: "SGVsbG8=", // "Hello" in base64
|
||||
media_type: "image/jpeg" as const,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
expect(result[0].role).toBe('user')
|
||||
expect(result[0].role).toBe("user")
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
expect(result[0].content[0]).toEqual({ text: 'Look at this image:' })
|
||||
expect(result[0].content[0]).toEqual({ text: "Look at this image:" })
|
||||
|
||||
const imageBlock = result[0].content[1] as ContentBlock
|
||||
if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) {
|
||||
expect(imageBlock.image.format).toBe('jpeg')
|
||||
if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) {
|
||||
expect(imageBlock.image.format).toBe("jpeg")
|
||||
expect(imageBlock.image.source).toBeDefined()
|
||||
expect(imageBlock.image.source.bytes).toBeDefined()
|
||||
} else {
|
||||
fail('Expected image block not found')
|
||||
fail("Expected image block not found")
|
||||
}
|
||||
})
|
||||
|
||||
test('converts tool use messages correctly', () => {
|
||||
test("converts tool use messages correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'test-id',
|
||||
name: 'read_file',
|
||||
type: "tool_use",
|
||||
id: "test-id",
|
||||
name: "read_file",
|
||||
input: {
|
||||
path: 'test.txt'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
path: "test.txt",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
expect(result[0].role).toBe('assistant')
|
||||
expect(result[0].role).toBe("assistant")
|
||||
const toolBlock = result[0].content[0] as ContentBlock
|
||||
if ('toolUse' in toolBlock && toolBlock.toolUse) {
|
||||
if ("toolUse" in toolBlock && toolBlock.toolUse) {
|
||||
expect(toolBlock.toolUse).toEqual({
|
||||
toolUseId: 'test-id',
|
||||
name: 'read_file',
|
||||
input: '<read_file>\n<path>\ntest.txt\n</path>\n</read_file>'
|
||||
toolUseId: "test-id",
|
||||
name: "read_file",
|
||||
input: "<read_file>\n<path>\ntest.txt\n</path>\n</read_file>",
|
||||
})
|
||||
} else {
|
||||
fail('Expected tool use block not found')
|
||||
fail("Expected tool use block not found")
|
||||
}
|
||||
})
|
||||
|
||||
test('converts tool result messages correctly', () => {
|
||||
test("converts tool result messages correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'test-id',
|
||||
content: [{ type: 'text', text: 'File contents here' }]
|
||||
}
|
||||
]
|
||||
}
|
||||
type: "tool_result",
|
||||
tool_use_id: "test-id",
|
||||
content: [{ type: "text", text: "File contents here" }],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
expect(result[0].role).toBe('assistant')
|
||||
expect(result[0].role).toBe("assistant")
|
||||
const resultBlock = result[0].content[0] as ContentBlock
|
||||
if ('toolResult' in resultBlock && resultBlock.toolResult) {
|
||||
const expectedContent: ToolResultContentBlock[] = [
|
||||
{ text: 'File contents here' }
|
||||
]
|
||||
if ("toolResult" in resultBlock && resultBlock.toolResult) {
|
||||
const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }]
|
||||
expect(resultBlock.toolResult).toEqual({
|
||||
toolUseId: 'test-id',
|
||||
toolUseId: "test-id",
|
||||
content: expectedContent,
|
||||
status: 'success'
|
||||
status: "success",
|
||||
})
|
||||
} else {
|
||||
fail('Expected tool result block not found')
|
||||
fail("Expected tool result block not found")
|
||||
}
|
||||
})
|
||||
|
||||
test('handles text content correctly', () => {
|
||||
test("handles text content correctly", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Hello world'
|
||||
}
|
||||
]
|
||||
}
|
||||
type: "text",
|
||||
text: "Hello world",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToBedrockConverseMessages(messages)
|
||||
|
||||
if (!result[0] || !result[0].content) {
|
||||
fail('Expected result to have content')
|
||||
fail("Expected result to have content")
|
||||
return
|
||||
}
|
||||
|
||||
expect(result[0].role).toBe('user')
|
||||
expect(result[0].role).toBe("user")
|
||||
expect(result[0].content).toHaveLength(1)
|
||||
const textBlock = result[0].content[0] as ContentBlock
|
||||
expect(textBlock).toEqual({ text: 'Hello world' })
|
||||
expect(textBlock).toEqual({ text: "Hello world" })
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertToAnthropicMessage', () => {
|
||||
test('converts metadata events correctly', () => {
|
||||
describe("convertToAnthropicMessage", () => {
|
||||
test("converts metadata events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
metadata: {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20
|
||||
}
|
||||
}
|
||||
outputTokens: 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
expect(result).toEqual({
|
||||
id: '',
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
model: 'test-model',
|
||||
id: "",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: "test-model",
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20
|
||||
}
|
||||
output_tokens: 20,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test('converts content block start events correctly', () => {
|
||||
test("converts content block start events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
contentBlockStart: {
|
||||
start: {
|
||||
text: 'Hello'
|
||||
}
|
||||
}
|
||||
text: "Hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Hello' }],
|
||||
model: 'test-model'
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello" }],
|
||||
model: "test-model",
|
||||
})
|
||||
})
|
||||
|
||||
test('converts content block delta events correctly', () => {
|
||||
test("converts content block delta events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
contentBlockDelta: {
|
||||
delta: {
|
||||
text: ' world'
|
||||
}
|
||||
}
|
||||
text: " world",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: ' world' }],
|
||||
model: 'test-model'
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: " world" }],
|
||||
model: "test-model",
|
||||
})
|
||||
})
|
||||
|
||||
test('converts message stop events correctly', () => {
|
||||
test("converts message stop events correctly", () => {
|
||||
const event: StreamEvent = {
|
||||
messageStop: {
|
||||
stopReason: 'end_turn' as const
|
||||
}
|
||||
stopReason: "end_turn" as const,
|
||||
},
|
||||
}
|
||||
|
||||
const result = convertToAnthropicMessage(event, 'test-model')
|
||||
const result = convertToAnthropicMessage(event, "test-model")
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
stop_reason: 'end_turn',
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
stop_reason: "end_turn",
|
||||
stop_sequence: null,
|
||||
model: 'test-model'
|
||||
model: "test-model",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,257 +1,275 @@
|
||||
import { convertToOpenAiMessages, convertToAnthropicMessage } from '../openai-format';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
import OpenAI from 'openai';
|
||||
import { convertToOpenAiMessages, convertToAnthropicMessage } from "../openai-format"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import OpenAI from "openai"
|
||||
|
||||
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, 'choices'> & {
|
||||
choices: Array<Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
|
||||
message: OpenAI.Chat.Completions.ChatCompletion.Choice['message'];
|
||||
finish_reason: string;
|
||||
index: number;
|
||||
}>;
|
||||
};
|
||||
type PartialChatCompletion = Omit<OpenAI.Chat.Completions.ChatCompletion, "choices"> & {
|
||||
choices: Array<
|
||||
Partial<OpenAI.Chat.Completions.ChatCompletion.Choice> & {
|
||||
message: OpenAI.Chat.Completions.ChatCompletion.Choice["message"]
|
||||
finish_reason: string
|
||||
index: number
|
||||
}
|
||||
>
|
||||
}
|
||||
|
||||
describe('OpenAI Format Transformations', () => {
|
||||
describe('convertToOpenAiMessages', () => {
|
||||
it('should convert simple text messages', () => {
|
||||
describe("OpenAI Format Transformations", () => {
|
||||
describe("convertToOpenAiMessages", () => {
|
||||
it("should convert simple text messages", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
}
|
||||
];
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
},
|
||||
]
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(2);
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(2)
|
||||
expect(openAiMessages[0]).toEqual({
|
||||
role: 'user',
|
||||
content: 'Hello'
|
||||
});
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
})
|
||||
expect(openAiMessages[1]).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'Hi there!'
|
||||
});
|
||||
});
|
||||
role: "assistant",
|
||||
content: "Hi there!",
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle messages with image content', () => {
|
||||
it("should handle messages with image content", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'What is in this image?'
|
||||
type: "text",
|
||||
text: "What is in this image?",
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
type: "image",
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'image/jpeg',
|
||||
data: 'base64data'
|
||||
}
|
||||
}
|
||||
type: "base64",
|
||||
media_type: "image/jpeg",
|
||||
data: "base64data",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
];
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(1);
|
||||
expect(openAiMessages[0].role).toBe('user');
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(1)
|
||||
expect(openAiMessages[0].role).toBe("user")
|
||||
|
||||
const content = openAiMessages[0].content as Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
image_url?: { url: string };
|
||||
}>;
|
||||
type: string
|
||||
text?: string
|
||||
image_url?: { url: string }
|
||||
}>
|
||||
|
||||
expect(Array.isArray(content)).toBe(true);
|
||||
expect(content).toHaveLength(2);
|
||||
expect(content[0]).toEqual({ type: 'text', text: 'What is in this image?' });
|
||||
expect(Array.isArray(content)).toBe(true)
|
||||
expect(content).toHaveLength(2)
|
||||
expect(content[0]).toEqual({ type: "text", text: "What is in this image?" })
|
||||
expect(content[1]).toEqual({
|
||||
type: 'image_url',
|
||||
image_url: { url: '' }
|
||||
});
|
||||
});
|
||||
type: "image_url",
|
||||
image_url: { url: "" },
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle assistant messages with tool use', () => {
|
||||
it("should handle assistant messages with tool use", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Let me check the weather.'
|
||||
type: "text",
|
||||
text: "Let me check the weather.",
|
||||
},
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'weather-123',
|
||||
name: 'get_weather',
|
||||
input: { city: 'London' }
|
||||
}
|
||||
type: "tool_use",
|
||||
id: "weather-123",
|
||||
name: "get_weather",
|
||||
input: { city: "London" },
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
];
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(1);
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(1)
|
||||
|
||||
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam;
|
||||
expect(assistantMessage.role).toBe('assistant');
|
||||
expect(assistantMessage.content).toBe('Let me check the weather.');
|
||||
expect(assistantMessage.tool_calls).toHaveLength(1);
|
||||
const assistantMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionAssistantMessageParam
|
||||
expect(assistantMessage.role).toBe("assistant")
|
||||
expect(assistantMessage.content).toBe("Let me check the weather.")
|
||||
expect(assistantMessage.tool_calls).toHaveLength(1)
|
||||
expect(assistantMessage.tool_calls![0]).toEqual({
|
||||
id: 'weather-123',
|
||||
type: 'function',
|
||||
id: "weather-123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: 'get_weather',
|
||||
arguments: JSON.stringify({ city: 'London' })
|
||||
}
|
||||
});
|
||||
});
|
||||
name: "get_weather",
|
||||
arguments: JSON.stringify({ city: "London" }),
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle user messages with tool results', () => {
|
||||
it("should handle user messages with tool results", () => {
|
||||
const anthropicMessages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'weather-123',
|
||||
content: 'Current temperature in London: 20°C'
|
||||
}
|
||||
]
|
||||
}
|
||||
];
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages);
|
||||
expect(openAiMessages).toHaveLength(1);
|
||||
|
||||
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam;
|
||||
expect(toolMessage.role).toBe('tool');
|
||||
expect(toolMessage.tool_call_id).toBe('weather-123');
|
||||
expect(toolMessage.content).toBe('Current temperature in London: 20°C');
|
||||
});
|
||||
});
|
||||
|
||||
describe('convertToAnthropicMessage', () => {
|
||||
it('should convert simple completion', () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: 'completion-123',
|
||||
model: 'gpt-4',
|
||||
choices: [{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'Hello there!',
|
||||
refusal: null
|
||||
type: "tool_result",
|
||||
tool_use_id: "weather-123",
|
||||
content: "Current temperature in London: 20°C",
|
||||
},
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const openAiMessages = convertToOpenAiMessages(anthropicMessages)
|
||||
expect(openAiMessages).toHaveLength(1)
|
||||
|
||||
const toolMessage = openAiMessages[0] as OpenAI.Chat.ChatCompletionToolMessageParam
|
||||
expect(toolMessage.role).toBe("tool")
|
||||
expect(toolMessage.tool_call_id).toBe("weather-123")
|
||||
expect(toolMessage.content).toBe("Current temperature in London: 20°C")
|
||||
})
|
||||
})
|
||||
|
||||
describe("convertToAnthropicMessage", () => {
|
||||
it("should convert simple completion", () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: "completion-123",
|
||||
model: "gpt-4",
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Hello there!",
|
||||
refusal: null,
|
||||
},
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
total_tokens: 15,
|
||||
},
|
||||
created: 123456789,
|
||||
object: 'chat.completion'
|
||||
};
|
||||
object: "chat.completion",
|
||||
}
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
||||
expect(anthropicMessage.id).toBe('completion-123');
|
||||
expect(anthropicMessage.role).toBe('assistant');
|
||||
expect(anthropicMessage.content).toHaveLength(1);
|
||||
const anthropicMessage = convertToAnthropicMessage(
|
||||
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||
)
|
||||
expect(anthropicMessage.id).toBe("completion-123")
|
||||
expect(anthropicMessage.role).toBe("assistant")
|
||||
expect(anthropicMessage.content).toHaveLength(1)
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello there!'
|
||||
});
|
||||
expect(anthropicMessage.stop_reason).toBe('end_turn');
|
||||
type: "text",
|
||||
text: "Hello there!",
|
||||
})
|
||||
expect(anthropicMessage.stop_reason).toBe("end_turn")
|
||||
expect(anthropicMessage.usage).toEqual({
|
||||
input_tokens: 10,
|
||||
output_tokens: 5
|
||||
});
|
||||
});
|
||||
output_tokens: 5,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tool calls in completion', () => {
|
||||
it("should handle tool calls in completion", () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: 'completion-123',
|
||||
model: 'gpt-4',
|
||||
choices: [{
|
||||
id: "completion-123",
|
||||
model: "gpt-4",
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'Let me check the weather.',
|
||||
tool_calls: [{
|
||||
id: 'weather-123',
|
||||
type: 'function',
|
||||
role: "assistant",
|
||||
content: "Let me check the weather.",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "weather-123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: 'get_weather',
|
||||
arguments: '{"city":"London"}'
|
||||
}
|
||||
}],
|
||||
refusal: null
|
||||
name: "get_weather",
|
||||
arguments: '{"city":"London"}',
|
||||
},
|
||||
finish_reason: 'tool_calls',
|
||||
index: 0
|
||||
}],
|
||||
},
|
||||
],
|
||||
refusal: null,
|
||||
},
|
||||
finish_reason: "tool_calls",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 15,
|
||||
completion_tokens: 8,
|
||||
total_tokens: 23
|
||||
total_tokens: 23,
|
||||
},
|
||||
created: 123456789,
|
||||
object: 'chat.completion'
|
||||
};
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
||||
expect(anthropicMessage.content).toHaveLength(2);
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Let me check the weather.'
|
||||
});
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: 'tool_use',
|
||||
id: 'weather-123',
|
||||
name: 'get_weather',
|
||||
input: { city: 'London' }
|
||||
});
|
||||
expect(anthropicMessage.stop_reason).toBe('tool_use');
|
||||
});
|
||||
|
||||
it('should handle invalid tool call arguments', () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: 'completion-123',
|
||||
model: 'gpt-4',
|
||||
choices: [{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'Testing invalid arguments',
|
||||
tool_calls: [{
|
||||
id: 'test-123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'test_function',
|
||||
arguments: 'invalid json'
|
||||
object: "chat.completion",
|
||||
}
|
||||
}],
|
||||
refusal: null
|
||||
},
|
||||
finish_reason: 'tool_calls',
|
||||
index: 0
|
||||
}],
|
||||
created: 123456789,
|
||||
object: 'chat.completion'
|
||||
};
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(openAiCompletion as OpenAI.Chat.Completions.ChatCompletion);
|
||||
expect(anthropicMessage.content).toHaveLength(2);
|
||||
const anthropicMessage = convertToAnthropicMessage(
|
||||
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||
)
|
||||
expect(anthropicMessage.content).toHaveLength(2)
|
||||
expect(anthropicMessage.content[0]).toEqual({
|
||||
type: "text",
|
||||
text: "Let me check the weather.",
|
||||
})
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: 'tool_use',
|
||||
id: 'test-123',
|
||||
name: 'test_function',
|
||||
input: {} // Should default to empty object for invalid JSON
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
type: "tool_use",
|
||||
id: "weather-123",
|
||||
name: "get_weather",
|
||||
input: { city: "London" },
|
||||
})
|
||||
expect(anthropicMessage.stop_reason).toBe("tool_use")
|
||||
})
|
||||
|
||||
it("should handle invalid tool call arguments", () => {
|
||||
const openAiCompletion: PartialChatCompletion = {
|
||||
id: "completion-123",
|
||||
model: "gpt-4",
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Testing invalid arguments",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "test-123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "test_function",
|
||||
arguments: "invalid json",
|
||||
},
|
||||
},
|
||||
],
|
||||
refusal: null,
|
||||
},
|
||||
finish_reason: "tool_calls",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
created: 123456789,
|
||||
object: "chat.completion",
|
||||
}
|
||||
|
||||
const anthropicMessage = convertToAnthropicMessage(
|
||||
openAiCompletion as OpenAI.Chat.Completions.ChatCompletion,
|
||||
)
|
||||
expect(anthropicMessage.content).toHaveLength(2)
|
||||
expect(anthropicMessage.content[1]).toEqual({
|
||||
type: "tool_use",
|
||||
id: "test-123",
|
||||
name: "test_function",
|
||||
input: {}, // Should default to empty object for invalid JSON
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,114 +1,114 @@
|
||||
import { ApiStreamChunk } from '../stream';
|
||||
import { ApiStreamChunk } from "../stream"
|
||||
|
||||
describe('API Stream Types', () => {
|
||||
describe('ApiStreamChunk', () => {
|
||||
it('should correctly handle text chunks', () => {
|
||||
describe("API Stream Types", () => {
|
||||
describe("ApiStreamChunk", () => {
|
||||
it("should correctly handle text chunks", () => {
|
||||
const textChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: 'Hello world'
|
||||
};
|
||||
type: "text",
|
||||
text: "Hello world",
|
||||
}
|
||||
|
||||
expect(textChunk.type).toBe('text');
|
||||
expect(textChunk.text).toBe('Hello world');
|
||||
});
|
||||
expect(textChunk.type).toBe("text")
|
||||
expect(textChunk.text).toBe("Hello world")
|
||||
})
|
||||
|
||||
it('should correctly handle usage chunks with cache information', () => {
|
||||
it("should correctly handle usage chunks with cache information", () => {
|
||||
const usageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 100,
|
||||
outputTokens: 50,
|
||||
cacheWriteTokens: 20,
|
||||
cacheReadTokens: 10
|
||||
};
|
||||
cacheReadTokens: 10,
|
||||
}
|
||||
|
||||
expect(usageChunk.type).toBe('usage');
|
||||
expect(usageChunk.inputTokens).toBe(100);
|
||||
expect(usageChunk.outputTokens).toBe(50);
|
||||
expect(usageChunk.cacheWriteTokens).toBe(20);
|
||||
expect(usageChunk.cacheReadTokens).toBe(10);
|
||||
});
|
||||
expect(usageChunk.type).toBe("usage")
|
||||
expect(usageChunk.inputTokens).toBe(100)
|
||||
expect(usageChunk.outputTokens).toBe(50)
|
||||
expect(usageChunk.cacheWriteTokens).toBe(20)
|
||||
expect(usageChunk.cacheReadTokens).toBe(10)
|
||||
})
|
||||
|
||||
it('should handle usage chunks without cache tokens', () => {
|
||||
it("should handle usage chunks without cache tokens", () => {
|
||||
const usageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 100,
|
||||
outputTokens: 50
|
||||
};
|
||||
outputTokens: 50,
|
||||
}
|
||||
|
||||
expect(usageChunk.type).toBe('usage');
|
||||
expect(usageChunk.inputTokens).toBe(100);
|
||||
expect(usageChunk.outputTokens).toBe(50);
|
||||
expect(usageChunk.cacheWriteTokens).toBeUndefined();
|
||||
expect(usageChunk.cacheReadTokens).toBeUndefined();
|
||||
});
|
||||
expect(usageChunk.type).toBe("usage")
|
||||
expect(usageChunk.inputTokens).toBe(100)
|
||||
expect(usageChunk.outputTokens).toBe(50)
|
||||
expect(usageChunk.cacheWriteTokens).toBeUndefined()
|
||||
expect(usageChunk.cacheReadTokens).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle text chunks with empty strings', () => {
|
||||
it("should handle text chunks with empty strings", () => {
|
||||
const emptyTextChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: ''
|
||||
};
|
||||
type: "text",
|
||||
text: "",
|
||||
}
|
||||
|
||||
expect(emptyTextChunk.type).toBe('text');
|
||||
expect(emptyTextChunk.text).toBe('');
|
||||
});
|
||||
expect(emptyTextChunk.type).toBe("text")
|
||||
expect(emptyTextChunk.text).toBe("")
|
||||
})
|
||||
|
||||
it('should handle usage chunks with zero tokens', () => {
|
||||
it("should handle usage chunks with zero tokens", () => {
|
||||
const zeroUsageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
};
|
||||
outputTokens: 0,
|
||||
}
|
||||
|
||||
expect(zeroUsageChunk.type).toBe('usage');
|
||||
expect(zeroUsageChunk.inputTokens).toBe(0);
|
||||
expect(zeroUsageChunk.outputTokens).toBe(0);
|
||||
});
|
||||
expect(zeroUsageChunk.type).toBe("usage")
|
||||
expect(zeroUsageChunk.inputTokens).toBe(0)
|
||||
expect(zeroUsageChunk.outputTokens).toBe(0)
|
||||
})
|
||||
|
||||
it('should handle usage chunks with large token counts', () => {
|
||||
it("should handle usage chunks with large token counts", () => {
|
||||
const largeUsageChunk: ApiStreamChunk = {
|
||||
type: 'usage',
|
||||
type: "usage",
|
||||
inputTokens: 1000000,
|
||||
outputTokens: 500000,
|
||||
cacheWriteTokens: 200000,
|
||||
cacheReadTokens: 100000
|
||||
};
|
||||
cacheReadTokens: 100000,
|
||||
}
|
||||
|
||||
expect(largeUsageChunk.type).toBe('usage');
|
||||
expect(largeUsageChunk.inputTokens).toBe(1000000);
|
||||
expect(largeUsageChunk.outputTokens).toBe(500000);
|
||||
expect(largeUsageChunk.cacheWriteTokens).toBe(200000);
|
||||
expect(largeUsageChunk.cacheReadTokens).toBe(100000);
|
||||
});
|
||||
expect(largeUsageChunk.type).toBe("usage")
|
||||
expect(largeUsageChunk.inputTokens).toBe(1000000)
|
||||
expect(largeUsageChunk.outputTokens).toBe(500000)
|
||||
expect(largeUsageChunk.cacheWriteTokens).toBe(200000)
|
||||
expect(largeUsageChunk.cacheReadTokens).toBe(100000)
|
||||
})
|
||||
|
||||
it('should handle text chunks with special characters', () => {
|
||||
it("should handle text chunks with special characters", () => {
|
||||
const specialCharsChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: '!@#$%^&*()_+-=[]{}|;:,.<>?`~'
|
||||
};
|
||||
type: "text",
|
||||
text: "!@#$%^&*()_+-=[]{}|;:,.<>?`~",
|
||||
}
|
||||
|
||||
expect(specialCharsChunk.type).toBe('text');
|
||||
expect(specialCharsChunk.text).toBe('!@#$%^&*()_+-=[]{}|;:,.<>?`~');
|
||||
});
|
||||
expect(specialCharsChunk.type).toBe("text")
|
||||
expect(specialCharsChunk.text).toBe("!@#$%^&*()_+-=[]{}|;:,.<>?`~")
|
||||
})
|
||||
|
||||
it('should handle text chunks with unicode characters', () => {
|
||||
it("should handle text chunks with unicode characters", () => {
|
||||
const unicodeChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: '你好世界👋🌍'
|
||||
};
|
||||
type: "text",
|
||||
text: "你好世界👋🌍",
|
||||
}
|
||||
|
||||
expect(unicodeChunk.type).toBe('text');
|
||||
expect(unicodeChunk.text).toBe('你好世界👋🌍');
|
||||
});
|
||||
expect(unicodeChunk.type).toBe("text")
|
||||
expect(unicodeChunk.text).toBe("你好世界👋🌍")
|
||||
})
|
||||
|
||||
it('should handle text chunks with multiline content', () => {
|
||||
it("should handle text chunks with multiline content", () => {
|
||||
const multilineChunk: ApiStreamChunk = {
|
||||
type: 'text',
|
||||
text: 'Line 1\nLine 2\nLine 3'
|
||||
};
|
||||
type: "text",
|
||||
text: "Line 1\nLine 2\nLine 3",
|
||||
}
|
||||
|
||||
expect(multilineChunk.type).toBe('text');
|
||||
expect(multilineChunk.text).toBe('Line 1\nLine 2\nLine 3');
|
||||
expect(multilineChunk.text.split('\n')).toHaveLength(3);
|
||||
});
|
||||
});
|
||||
});
|
||||
expect(multilineChunk.type).toBe("text")
|
||||
expect(multilineChunk.text).toBe("Line 1\nLine 2\nLine 3")
|
||||
expect(multilineChunk.text.split("\n")).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,66 +1,66 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk";
|
||||
import * as vscode from 'vscode';
|
||||
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from '../vscode-lm-format';
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import * as vscode from "vscode"
|
||||
import { convertToVsCodeLmMessages, convertToAnthropicRole, convertToAnthropicMessage } from "../vscode-lm-format"
|
||||
|
||||
// Mock crypto
|
||||
const mockCrypto = {
|
||||
randomUUID: () => 'test-uuid'
|
||||
};
|
||||
global.crypto = mockCrypto as any;
|
||||
randomUUID: () => "test-uuid",
|
||||
}
|
||||
global.crypto = mockCrypto as any
|
||||
|
||||
// Define types for our mocked classes
|
||||
interface MockLanguageModelTextPart {
|
||||
type: 'text';
|
||||
value: string;
|
||||
type: "text"
|
||||
value: string
|
||||
}
|
||||
|
||||
interface MockLanguageModelToolCallPart {
|
||||
type: 'tool_call';
|
||||
callId: string;
|
||||
name: string;
|
||||
input: any;
|
||||
type: "tool_call"
|
||||
callId: string
|
||||
name: string
|
||||
input: any
|
||||
}
|
||||
|
||||
interface MockLanguageModelToolResultPart {
|
||||
type: 'tool_result';
|
||||
toolUseId: string;
|
||||
parts: MockLanguageModelTextPart[];
|
||||
type: "tool_result"
|
||||
toolUseId: string
|
||||
parts: MockLanguageModelTextPart[]
|
||||
}
|
||||
|
||||
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart;
|
||||
type MockMessageContent = MockLanguageModelTextPart | MockLanguageModelToolCallPart | MockLanguageModelToolResultPart
|
||||
|
||||
interface MockLanguageModelChatMessage {
|
||||
role: string;
|
||||
name?: string;
|
||||
content: MockMessageContent[];
|
||||
role: string
|
||||
name?: string
|
||||
content: MockMessageContent[]
|
||||
}
|
||||
|
||||
// Mock vscode namespace
|
||||
jest.mock('vscode', () => {
|
||||
jest.mock("vscode", () => {
|
||||
const LanguageModelChatMessageRole = {
|
||||
Assistant: 'assistant',
|
||||
User: 'user'
|
||||
};
|
||||
Assistant: "assistant",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
class MockLanguageModelTextPart {
|
||||
type = 'text';
|
||||
type = "text"
|
||||
constructor(public value: string) {}
|
||||
}
|
||||
|
||||
class MockLanguageModelToolCallPart {
|
||||
type = 'tool_call';
|
||||
type = "tool_call"
|
||||
constructor(
|
||||
public callId: string,
|
||||
public name: string,
|
||||
public input: any
|
||||
public input: any,
|
||||
) {}
|
||||
}
|
||||
|
||||
class MockLanguageModelToolResultPart {
|
||||
type = 'tool_result';
|
||||
type = "tool_result"
|
||||
constructor(
|
||||
public toolUseId: string,
|
||||
public parts: MockLanguageModelTextPart[]
|
||||
public parts: MockLanguageModelTextPart[],
|
||||
) {}
|
||||
}
|
||||
|
||||
@@ -68,179 +68,189 @@ jest.mock('vscode', () => {
|
||||
LanguageModelChatMessage: {
|
||||
Assistant: jest.fn((content) => ({
|
||||
role: LanguageModelChatMessageRole.Assistant,
|
||||
name: 'assistant',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
name: "assistant",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
User: jest.fn((content) => ({
|
||||
role: LanguageModelChatMessageRole.User,
|
||||
name: 'user',
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)]
|
||||
}))
|
||||
name: "user",
|
||||
content: Array.isArray(content) ? content : [new MockLanguageModelTextPart(content)],
|
||||
})),
|
||||
},
|
||||
LanguageModelChatMessageRole,
|
||||
LanguageModelTextPart: MockLanguageModelTextPart,
|
||||
LanguageModelToolCallPart: MockLanguageModelToolCallPart,
|
||||
LanguageModelToolResultPart: MockLanguageModelToolResultPart
|
||||
};
|
||||
});
|
||||
LanguageModelToolResultPart: MockLanguageModelToolResultPart,
|
||||
}
|
||||
})
|
||||
|
||||
describe('vscode-lm-format', () => {
|
||||
describe('convertToVsCodeLmMessages', () => {
|
||||
it('should convert simple string messages', () => {
|
||||
describe("vscode-lm-format", () => {
|
||||
describe("convertToVsCodeLmMessages", () => {
|
||||
it("should convert simple string messages", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
{ role: 'assistant', content: 'Hi there' }
|
||||
];
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].role).toBe('user');
|
||||
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe('Hello');
|
||||
expect(result[1].role).toBe('assistant');
|
||||
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe('Hi there');
|
||||
});
|
||||
|
||||
it('should handle complex user messages with tool results', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Here is the result:' },
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'tool-1',
|
||||
content: 'Tool output'
|
||||
}
|
||||
{ role: "user", content: "Hello" },
|
||||
{ role: "assistant", content: "Hi there" },
|
||||
]
|
||||
}];
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].role).toBe('user');
|
||||
expect(result[0].content).toHaveLength(2);
|
||||
const [toolResult, textContent] = result[0].content as [MockLanguageModelToolResultPart, MockLanguageModelTextPart];
|
||||
expect(toolResult.type).toBe('tool_result');
|
||||
expect(textContent.type).toBe('text');
|
||||
});
|
||||
expect(result).toHaveLength(2)
|
||||
expect(result[0].role).toBe("user")
|
||||
expect((result[0].content[0] as MockLanguageModelTextPart).value).toBe("Hello")
|
||||
expect(result[1].role).toBe("assistant")
|
||||
expect((result[1].content[0] as MockLanguageModelTextPart).value).toBe("Hi there")
|
||||
})
|
||||
|
||||
it('should handle complex assistant messages with tool calls', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{ type: 'text', text: 'Let me help you with that.' },
|
||||
it("should handle complex user messages with tool results", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'tool-1',
|
||||
name: 'calculator',
|
||||
input: { operation: 'add', numbers: [2, 2] }
|
||||
}
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "Here is the result:" },
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: "tool-1",
|
||||
content: "Tool output",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}];
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].role).toBe('assistant');
|
||||
expect(result[0].content).toHaveLength(2);
|
||||
const [toolCall, textContent] = result[0].content as [MockLanguageModelToolCallPart, MockLanguageModelTextPart];
|
||||
expect(toolCall.type).toBe('tool_call');
|
||||
expect(textContent.type).toBe('text');
|
||||
});
|
||||
expect(result).toHaveLength(1)
|
||||
expect(result[0].role).toBe("user")
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
const [toolResult, textContent] = result[0].content as [
|
||||
MockLanguageModelToolResultPart,
|
||||
MockLanguageModelTextPart,
|
||||
]
|
||||
expect(toolResult.type).toBe("tool_result")
|
||||
expect(textContent.type).toBe("text")
|
||||
})
|
||||
|
||||
it('should handle image blocks with appropriate placeholders', () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Look at this:' },
|
||||
it("should handle complex assistant messages with tool calls", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
type: 'image',
|
||||
role: "assistant",
|
||||
content: [
|
||||
{ type: "text", text: "Let me help you with that." },
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "tool-1",
|
||||
name: "calculator",
|
||||
input: { operation: "add", numbers: [2, 2] },
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
expect(result).toHaveLength(1)
|
||||
expect(result[0].role).toBe("assistant")
|
||||
expect(result[0].content).toHaveLength(2)
|
||||
const [toolCall, textContent] = result[0].content as [
|
||||
MockLanguageModelToolCallPart,
|
||||
MockLanguageModelTextPart,
|
||||
]
|
||||
expect(toolCall.type).toBe("tool_call")
|
||||
expect(textContent.type).toBe("text")
|
||||
})
|
||||
|
||||
it("should handle image blocks with appropriate placeholders", () => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "Look at this:" },
|
||||
{
|
||||
type: "image",
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'image/png',
|
||||
data: 'base64data'
|
||||
}
|
||||
}
|
||||
type: "base64",
|
||||
media_type: "image/png",
|
||||
data: "base64data",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}];
|
||||
|
||||
const result = convertToVsCodeLmMessages(messages);
|
||||
const result = convertToVsCodeLmMessages(messages)
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart;
|
||||
expect(imagePlaceholder.value).toContain('[Image (base64): image/png not supported by VSCode LM API]');
|
||||
});
|
||||
});
|
||||
expect(result).toHaveLength(1)
|
||||
const imagePlaceholder = result[0].content[1] as MockLanguageModelTextPart
|
||||
expect(imagePlaceholder.value).toContain("[Image (base64): image/png not supported by VSCode LM API]")
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertToAnthropicRole', () => {
|
||||
it('should convert assistant role correctly', () => {
|
||||
const result = convertToAnthropicRole('assistant' as any);
|
||||
expect(result).toBe('assistant');
|
||||
});
|
||||
describe("convertToAnthropicRole", () => {
|
||||
it("should convert assistant role correctly", () => {
|
||||
const result = convertToAnthropicRole("assistant" as any)
|
||||
expect(result).toBe("assistant")
|
||||
})
|
||||
|
||||
it('should convert user role correctly', () => {
|
||||
const result = convertToAnthropicRole('user' as any);
|
||||
expect(result).toBe('user');
|
||||
});
|
||||
it("should convert user role correctly", () => {
|
||||
const result = convertToAnthropicRole("user" as any)
|
||||
expect(result).toBe("user")
|
||||
})
|
||||
|
||||
it('should return null for unknown roles', () => {
|
||||
const result = convertToAnthropicRole('unknown' as any);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
it("should return null for unknown roles", () => {
|
||||
const result = convertToAnthropicRole("unknown" as any)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertToAnthropicMessage', () => {
|
||||
it('should convert assistant message with text content', async () => {
|
||||
describe("convertToAnthropicMessage", () => {
|
||||
it("should convert assistant message with text content", async () => {
|
||||
const vsCodeMessage = {
|
||||
role: 'assistant',
|
||||
name: 'assistant',
|
||||
content: [new vscode.LanguageModelTextPart('Hello')]
|
||||
};
|
||||
role: "assistant",
|
||||
name: "assistant",
|
||||
content: [new vscode.LanguageModelTextPart("Hello")],
|
||||
}
|
||||
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any);
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any)
|
||||
|
||||
expect(result.role).toBe('assistant');
|
||||
expect(result.content).toHaveLength(1);
|
||||
expect(result.role).toBe("assistant")
|
||||
expect(result.content).toHaveLength(1)
|
||||
expect(result.content[0]).toEqual({
|
||||
type: 'text',
|
||||
text: 'Hello'
|
||||
});
|
||||
expect(result.id).toBe('test-uuid');
|
||||
});
|
||||
type: "text",
|
||||
text: "Hello",
|
||||
})
|
||||
expect(result.id).toBe("test-uuid")
|
||||
})
|
||||
|
||||
it('should convert assistant message with tool calls', async () => {
|
||||
it("should convert assistant message with tool calls", async () => {
|
||||
const vsCodeMessage = {
|
||||
role: 'assistant',
|
||||
name: 'assistant',
|
||||
content: [new vscode.LanguageModelToolCallPart(
|
||||
'call-1',
|
||||
'calculator',
|
||||
{ operation: 'add', numbers: [2, 2] }
|
||||
)]
|
||||
};
|
||||
role: "assistant",
|
||||
name: "assistant",
|
||||
content: [
|
||||
new vscode.LanguageModelToolCallPart("call-1", "calculator", { operation: "add", numbers: [2, 2] }),
|
||||
],
|
||||
}
|
||||
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any);
|
||||
const result = await convertToAnthropicMessage(vsCodeMessage as any)
|
||||
|
||||
expect(result.content).toHaveLength(1);
|
||||
expect(result.content).toHaveLength(1)
|
||||
expect(result.content[0]).toEqual({
|
||||
type: 'tool_use',
|
||||
id: 'call-1',
|
||||
name: 'calculator',
|
||||
input: { operation: 'add', numbers: [2, 2] }
|
||||
});
|
||||
expect(result.id).toBe('test-uuid');
|
||||
});
|
||||
type: "tool_use",
|
||||
id: "call-1",
|
||||
name: "calculator",
|
||||
input: { operation: "add", numbers: [2, 2] },
|
||||
})
|
||||
expect(result.id).toBe("test-uuid")
|
||||
})
|
||||
|
||||
it('should throw error for non-assistant messages', async () => {
|
||||
it("should throw error for non-assistant messages", async () => {
|
||||
const vsCodeMessage = {
|
||||
role: 'user',
|
||||
name: 'user',
|
||||
content: [new vscode.LanguageModelTextPart('Hello')]
|
||||
};
|
||||
role: "user",
|
||||
name: "user",
|
||||
content: [new vscode.LanguageModelTextPart("Hello")],
|
||||
}
|
||||
|
||||
await expect(convertToAnthropicMessage(vsCodeMessage as any))
|
||||
.rejects
|
||||
.toThrow('Cline <Language Model API>: Only assistant messages are supported.');
|
||||
});
|
||||
});
|
||||
});
|
||||
await expect(convertToAnthropicMessage(vsCodeMessage as any)).rejects.toThrow(
|
||||
"Cline <Language Model API>: Only assistant messages are supported.",
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,41 +8,41 @@ import { StreamEvent } from "../providers/bedrock"
|
||||
/**
|
||||
* Convert Anthropic messages to Bedrock Converse format
|
||||
*/
|
||||
export function convertToBedrockConverseMessages(
|
||||
anthropicMessages: Anthropic.Messages.MessageParam[]
|
||||
): Message[] {
|
||||
return anthropicMessages.map(anthropicMessage => {
|
||||
export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
|
||||
return anthropicMessages.map((anthropicMessage) => {
|
||||
// Map Anthropic roles to Bedrock roles
|
||||
const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user"
|
||||
|
||||
if (typeof anthropicMessage.content === "string") {
|
||||
return {
|
||||
role,
|
||||
content: [{
|
||||
text: anthropicMessage.content
|
||||
}] as ContentBlock[]
|
||||
content: [
|
||||
{
|
||||
text: anthropicMessage.content,
|
||||
},
|
||||
] as ContentBlock[],
|
||||
}
|
||||
}
|
||||
|
||||
// Process complex content types
|
||||
const content = anthropicMessage.content.map(block => {
|
||||
const content = anthropicMessage.content.map((block) => {
|
||||
const messageBlock = block as MessageContent & {
|
||||
id?: string,
|
||||
tool_use_id?: string,
|
||||
content?: Array<{ type: string, text: string }>,
|
||||
output?: string | Array<{ type: string, text: string }>
|
||||
id?: string
|
||||
tool_use_id?: string
|
||||
content?: Array<{ type: string; text: string }>
|
||||
output?: string | Array<{ type: string; text: string }>
|
||||
}
|
||||
|
||||
if (messageBlock.type === "text") {
|
||||
return {
|
||||
text: messageBlock.text || ''
|
||||
text: messageBlock.text || "",
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
if (messageBlock.type === "image" && messageBlock.source) {
|
||||
// Convert base64 string to byte array if needed
|
||||
let byteArray: Uint8Array
|
||||
if (typeof messageBlock.source.data === 'string') {
|
||||
if (typeof messageBlock.source.data === "string") {
|
||||
const binaryString = atob(messageBlock.source.data)
|
||||
byteArray = new Uint8Array(binaryString.length)
|
||||
for (let i = 0; i < binaryString.length; i++) {
|
||||
@@ -53,8 +53,8 @@ export function convertToBedrockConverseMessages(
|
||||
}
|
||||
|
||||
// Extract format from media_type (e.g., "image/jpeg" -> "jpeg")
|
||||
const format = messageBlock.source.media_type.split('/')[1]
|
||||
if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) {
|
||||
const format = messageBlock.source.media_type.split("/")[1]
|
||||
if (!["png", "jpeg", "gif", "webp"].includes(format)) {
|
||||
throw new Error(`Unsupported image format: ${format}`)
|
||||
}
|
||||
|
||||
@@ -62,9 +62,9 @@ export function convertToBedrockConverseMessages(
|
||||
image: {
|
||||
format: format as "png" | "jpeg" | "gif" | "webp",
|
||||
source: {
|
||||
bytes: byteArray
|
||||
}
|
||||
}
|
||||
bytes: byteArray,
|
||||
},
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
@@ -72,14 +72,14 @@ export function convertToBedrockConverseMessages(
|
||||
// Convert tool use to XML format
|
||||
const toolParams = Object.entries(messageBlock.input || {})
|
||||
.map(([key, value]) => `<${key}>\n${value}\n</${key}>`)
|
||||
.join('\n')
|
||||
.join("\n")
|
||||
|
||||
return {
|
||||
toolUse: {
|
||||
toolUseId: messageBlock.id || '',
|
||||
name: messageBlock.name || '',
|
||||
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`
|
||||
}
|
||||
toolUseId: messageBlock.id || "",
|
||||
name: messageBlock.name || "",
|
||||
input: `<${messageBlock.name}>\n${toolParams}\n</${messageBlock.name}>`,
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
@@ -88,12 +88,12 @@ export function convertToBedrockConverseMessages(
|
||||
if (messageBlock.content && Array.isArray(messageBlock.content)) {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: messageBlock.content.map(item => ({
|
||||
text: item.text
|
||||
toolUseId: messageBlock.tool_use_id || "",
|
||||
content: messageBlock.content.map((item) => ({
|
||||
text: item.text,
|
||||
})),
|
||||
status: "success"
|
||||
}
|
||||
status: "success",
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
@@ -101,20 +101,22 @@ export function convertToBedrockConverseMessages(
|
||||
if (messageBlock.output && typeof messageBlock.output === "string") {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: [{
|
||||
text: messageBlock.output
|
||||
}],
|
||||
status: "success"
|
||||
}
|
||||
toolUseId: messageBlock.tool_use_id || "",
|
||||
content: [
|
||||
{
|
||||
text: messageBlock.output,
|
||||
},
|
||||
],
|
||||
status: "success",
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
// Handle array of content blocks if output is an array
|
||||
if (Array.isArray(messageBlock.output)) {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: messageBlock.output.map(part => {
|
||||
toolUseId: messageBlock.tool_use_id || "",
|
||||
content: messageBlock.output.map((part) => {
|
||||
if (typeof part === "object" && "text" in part) {
|
||||
return { text: part.text }
|
||||
}
|
||||
@@ -124,48 +126,52 @@ export function convertToBedrockConverseMessages(
|
||||
}
|
||||
return { text: String(part) }
|
||||
}),
|
||||
status: "success"
|
||||
}
|
||||
status: "success",
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
// Default case
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: messageBlock.tool_use_id || '',
|
||||
content: [{
|
||||
text: String(messageBlock.output || '')
|
||||
}],
|
||||
status: "success"
|
||||
}
|
||||
toolUseId: messageBlock.tool_use_id || "",
|
||||
content: [
|
||||
{
|
||||
text: String(messageBlock.output || ""),
|
||||
},
|
||||
],
|
||||
status: "success",
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
if (messageBlock.type === "video") {
|
||||
const videoContent = messageBlock.s3Location ? {
|
||||
const videoContent = messageBlock.s3Location
|
||||
? {
|
||||
s3Location: {
|
||||
uri: messageBlock.s3Location.uri,
|
||||
bucketOwner: messageBlock.s3Location.bucketOwner
|
||||
bucketOwner: messageBlock.s3Location.bucketOwner,
|
||||
},
|
||||
}
|
||||
} : messageBlock.source
|
||||
: messageBlock.source
|
||||
|
||||
return {
|
||||
video: {
|
||||
format: "mp4", // Default to mp4, adjust based on actual format if needed
|
||||
source: videoContent
|
||||
}
|
||||
source: videoContent,
|
||||
},
|
||||
} as ContentBlock
|
||||
}
|
||||
|
||||
// Default case for unknown block types
|
||||
return {
|
||||
text: '[Unknown Block Type]'
|
||||
text: "[Unknown Block Type]",
|
||||
} as ContentBlock
|
||||
})
|
||||
|
||||
return {
|
||||
role,
|
||||
content
|
||||
content,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -175,19 +181,19 @@ export function convertToBedrockConverseMessages(
|
||||
*/
|
||||
export function convertToAnthropicMessage(
|
||||
streamEvent: StreamEvent,
|
||||
modelId: string
|
||||
modelId: string,
|
||||
): Partial<Anthropic.Messages.Message> {
|
||||
// Handle metadata events
|
||||
if (streamEvent.metadata?.usage) {
|
||||
return {
|
||||
id: '', // Bedrock doesn't provide message IDs
|
||||
id: "", // Bedrock doesn't provide message IDs
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: modelId,
|
||||
usage: {
|
||||
input_tokens: streamEvent.metadata.usage.inputTokens || 0,
|
||||
output_tokens: streamEvent.metadata.usage.outputTokens || 0
|
||||
}
|
||||
output_tokens: streamEvent.metadata.usage.outputTokens || 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +204,7 @@ export function convertToAnthropicMessage(
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: text }],
|
||||
model: modelId
|
||||
model: modelId,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,7 +215,7 @@ export function convertToAnthropicMessage(
|
||||
role: "assistant",
|
||||
stop_reason: streamEvent.messageStop.stopReason || null,
|
||||
stop_sequence: null,
|
||||
model: modelId
|
||||
model: modelId,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Anthropic } from "@anthropic-ai/sdk";
|
||||
import * as vscode from 'vscode';
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import * as vscode from "vscode"
|
||||
|
||||
/**
|
||||
* Safely converts a value into a plain object.
|
||||
@@ -7,30 +7,31 @@ import * as vscode from 'vscode';
|
||||
function asObjectSafe(value: any): object {
|
||||
// Handle null/undefined
|
||||
if (!value) {
|
||||
return {};
|
||||
return {}
|
||||
}
|
||||
|
||||
try {
|
||||
// Handle strings that might be JSON
|
||||
if (typeof value === 'string') {
|
||||
return JSON.parse(value);
|
||||
if (typeof value === "string") {
|
||||
return JSON.parse(value)
|
||||
}
|
||||
|
||||
// Handle pre-existing objects
|
||||
if (typeof value === 'object') {
|
||||
return Object.assign({}, value);
|
||||
if (typeof value === "object") {
|
||||
return Object.assign({}, value)
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
catch (error) {
|
||||
console.warn('Cline <Language Model API>: Failed to parse object:', error);
|
||||
return {};
|
||||
return {}
|
||||
} catch (error) {
|
||||
console.warn("Cline <Language Model API>: Failed to parse object:", error)
|
||||
return {}
|
||||
}
|
||||
}
|
||||
|
||||
export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): vscode.LanguageModelChatMessage[] {
|
||||
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = [];
|
||||
export function convertToVsCodeLmMessages(
|
||||
anthropicMessages: Anthropic.Messages.MessageParam[],
|
||||
): vscode.LanguageModelChatMessage[] {
|
||||
const vsCodeLmMessages: vscode.LanguageModelChatMessage[] = []
|
||||
|
||||
for (const anthropicMessage of anthropicMessages) {
|
||||
// Handle simple string messages
|
||||
@@ -38,135 +39,129 @@ export function convertToVsCodeLmMessages(anthropicMessages: Anthropic.Messages.
|
||||
vsCodeLmMessages.push(
|
||||
anthropicMessage.role === "assistant"
|
||||
? vscode.LanguageModelChatMessage.Assistant(anthropicMessage.content)
|
||||
: vscode.LanguageModelChatMessage.User(anthropicMessage.content)
|
||||
);
|
||||
continue;
|
||||
: vscode.LanguageModelChatMessage.User(anthropicMessage.content),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle complex message structures
|
||||
switch (anthropicMessage.role) {
|
||||
case "user": {
|
||||
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
|
||||
toolMessages: Anthropic.ToolResultBlockParam[];
|
||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
|
||||
toolMessages: Anthropic.ToolResultBlockParam[]
|
||||
}>(
|
||||
(acc, part) => {
|
||||
if (part.type === "tool_result") {
|
||||
acc.toolMessages.push(part);
|
||||
acc.toolMessages.push(part)
|
||||
} else if (part.type === "text" || part.type === "image") {
|
||||
acc.nonToolMessages.push(part)
|
||||
}
|
||||
else if (part.type === "text" || part.type === "image") {
|
||||
acc.nonToolMessages.push(part);
|
||||
}
|
||||
return acc;
|
||||
return acc
|
||||
},
|
||||
{ nonToolMessages: [], toolMessages: [] },
|
||||
);
|
||||
)
|
||||
|
||||
// Process tool messages first then non-tool messages
|
||||
const contentParts = [
|
||||
// Convert tool messages to ToolResultParts
|
||||
...toolMessages.map((toolMessage) => {
|
||||
// Process tool result content into TextParts
|
||||
const toolContentParts: vscode.LanguageModelTextPart[] = (
|
||||
const toolContentParts: vscode.LanguageModelTextPart[] =
|
||||
typeof toolMessage.content === "string"
|
||||
? [new vscode.LanguageModelTextPart(toolMessage.content)]
|
||||
: (
|
||||
toolMessage.content?.map((part) => {
|
||||
: (toolMessage.content?.map((part) => {
|
||||
if (part.type === "image") {
|
||||
return new vscode.LanguageModelTextPart(
|
||||
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
|
||||
);
|
||||
}
|
||||
return new vscode.LanguageModelTextPart(part.text);
|
||||
})
|
||||
?? [new vscode.LanguageModelTextPart("")]
|
||||
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
|
||||
)
|
||||
);
|
||||
}
|
||||
return new vscode.LanguageModelTextPart(part.text)
|
||||
}) ?? [new vscode.LanguageModelTextPart("")])
|
||||
|
||||
return new vscode.LanguageModelToolResultPart(
|
||||
toolMessage.tool_use_id,
|
||||
toolContentParts
|
||||
);
|
||||
return new vscode.LanguageModelToolResultPart(toolMessage.tool_use_id, toolContentParts)
|
||||
}),
|
||||
|
||||
// Convert non-tool messages to TextParts after tool messages
|
||||
...nonToolMessages.map((part) => {
|
||||
if (part.type === "image") {
|
||||
return new vscode.LanguageModelTextPart(
|
||||
`[Image (${part.source?.type || 'Unknown source-type'}): ${part.source?.media_type || 'unknown media-type'} not supported by VSCode LM API]`
|
||||
);
|
||||
`[Image (${part.source?.type || "Unknown source-type"}): ${part.source?.media_type || "unknown media-type"} not supported by VSCode LM API]`,
|
||||
)
|
||||
}
|
||||
return new vscode.LanguageModelTextPart(part.text);
|
||||
})
|
||||
];
|
||||
return new vscode.LanguageModelTextPart(part.text)
|
||||
}),
|
||||
]
|
||||
|
||||
// Add single user message with all content parts
|
||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts));
|
||||
break;
|
||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.User(contentParts))
|
||||
break
|
||||
}
|
||||
|
||||
case "assistant": {
|
||||
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
|
||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[];
|
||||
toolMessages: Anthropic.ToolUseBlockParam[];
|
||||
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
|
||||
toolMessages: Anthropic.ToolUseBlockParam[]
|
||||
}>(
|
||||
(acc, part) => {
|
||||
if (part.type === "tool_use") {
|
||||
acc.toolMessages.push(part);
|
||||
acc.toolMessages.push(part)
|
||||
} else if (part.type === "text" || part.type === "image") {
|
||||
acc.nonToolMessages.push(part)
|
||||
}
|
||||
else if (part.type === "text" || part.type === "image") {
|
||||
acc.nonToolMessages.push(part);
|
||||
}
|
||||
return acc;
|
||||
return acc
|
||||
},
|
||||
{ nonToolMessages: [], toolMessages: [] },
|
||||
);
|
||||
)
|
||||
|
||||
// Process tool messages first then non-tool messages
|
||||
const contentParts = [
|
||||
// Convert tool messages to ToolCallParts first
|
||||
...toolMessages.map((toolMessage) =>
|
||||
...toolMessages.map(
|
||||
(toolMessage) =>
|
||||
new vscode.LanguageModelToolCallPart(
|
||||
toolMessage.id,
|
||||
toolMessage.name,
|
||||
asObjectSafe(toolMessage.input)
|
||||
)
|
||||
asObjectSafe(toolMessage.input),
|
||||
),
|
||||
),
|
||||
|
||||
// Convert non-tool messages to TextParts after tool messages
|
||||
...nonToolMessages.map((part) => {
|
||||
if (part.type === "image") {
|
||||
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]");
|
||||
return new vscode.LanguageModelTextPart("[Image generation not supported by VSCode LM API]")
|
||||
}
|
||||
return new vscode.LanguageModelTextPart(part.text);
|
||||
})
|
||||
];
|
||||
return new vscode.LanguageModelTextPart(part.text)
|
||||
}),
|
||||
]
|
||||
|
||||
// Add the assistant message to the list of messages
|
||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts));
|
||||
break;
|
||||
vsCodeLmMessages.push(vscode.LanguageModelChatMessage.Assistant(contentParts))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vsCodeLmMessages;
|
||||
return vsCodeLmMessages
|
||||
}
|
||||
|
||||
export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModelChatMessageRole): string | null {
|
||||
switch (vsCodeLmMessageRole) {
|
||||
case vscode.LanguageModelChatMessageRole.Assistant:
|
||||
return "assistant";
|
||||
return "assistant"
|
||||
case vscode.LanguageModelChatMessageRole.User:
|
||||
return "user";
|
||||
return "user"
|
||||
default:
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.LanguageModelChatMessage): Promise<Anthropic.Messages.Message> {
|
||||
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role);
|
||||
export async function convertToAnthropicMessage(
|
||||
vsCodeLmMessage: vscode.LanguageModelChatMessage,
|
||||
): Promise<Anthropic.Messages.Message> {
|
||||
const anthropicRole: string | null = convertToAnthropicRole(vsCodeLmMessage.role)
|
||||
if (anthropicRole !== "assistant") {
|
||||
throw new Error("Cline <Language Model API>: Only assistant messages are supported.");
|
||||
throw new Error("Cline <Language Model API>: Only assistant messages are supported.")
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -174,14 +169,13 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
|
||||
type: "message",
|
||||
model: "vscode-lm",
|
||||
role: anthropicRole,
|
||||
content: (
|
||||
vsCodeLmMessage.content
|
||||
content: vsCodeLmMessage.content
|
||||
.map((part): Anthropic.ContentBlock | null => {
|
||||
if (part instanceof vscode.LanguageModelTextPart) {
|
||||
return {
|
||||
type: "text",
|
||||
text: part.value
|
||||
};
|
||||
text: part.value,
|
||||
}
|
||||
}
|
||||
|
||||
if (part instanceof vscode.LanguageModelToolCallPart) {
|
||||
@@ -189,21 +183,18 @@ export async function convertToAnthropicMessage(vsCodeLmMessage: vscode.Language
|
||||
type: "tool_use",
|
||||
id: part.callId || crypto.randomUUID(),
|
||||
name: part.name,
|
||||
input: asObjectSafe(part.input)
|
||||
};
|
||||
input: asObjectSafe(part.input),
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return null
|
||||
})
|
||||
.filter(
|
||||
(part): part is Anthropic.ContentBlock => part !== null
|
||||
)
|
||||
),
|
||||
.filter((part): part is Anthropic.ContentBlock => part !== null),
|
||||
stop_reason: null,
|
||||
stop_sequence: null,
|
||||
usage: {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -13,7 +13,13 @@ import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
|
||||
import { ApiStream } from "../api/transform/stream"
|
||||
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
|
||||
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
|
||||
import { extractTextFromFile, addLineNumbers, stripLineNumbers, everyLineHasLineNumbers, truncateOutput } from "../integrations/misc/extract-text"
|
||||
import {
|
||||
extractTextFromFile,
|
||||
addLineNumbers,
|
||||
stripLineNumbers,
|
||||
everyLineHasLineNumbers,
|
||||
truncateOutput,
|
||||
} from "../integrations/misc/extract-text"
|
||||
import { TerminalManager } from "../integrations/terminal/TerminalManager"
|
||||
import { UrlContentFetcher } from "../services/browser/UrlContentFetcher"
|
||||
import { listFiles } from "../services/glob/list-files"
|
||||
@@ -112,7 +118,7 @@ export class Cline {
|
||||
experimentalDiffStrategy: boolean = false,
|
||||
) {
|
||||
if (!task && !images && !historyItem) {
|
||||
throw new Error('Either historyItem or task/images must be provided');
|
||||
throw new Error("Either historyItem or task/images must be provided")
|
||||
}
|
||||
|
||||
this.taskId = crypto.randomUUID()
|
||||
@@ -144,7 +150,8 @@ export class Cline {
|
||||
async updateDiffStrategy(experimentalDiffStrategy?: boolean) {
|
||||
// If not provided, get from current state
|
||||
if (experimentalDiffStrategy === undefined) {
|
||||
const { experimentalDiffStrategy: stateExperimentalDiffStrategy } = await this.providerRef.deref()?.getState() ?? {}
|
||||
const { experimentalDiffStrategy: stateExperimentalDiffStrategy } =
|
||||
(await this.providerRef.deref()?.getState()) ?? {}
|
||||
experimentalDiffStrategy = stateExperimentalDiffStrategy ?? false
|
||||
}
|
||||
this.diffStrategy = getDiffStrategy(this.api.getModel().id, this.fuzzyMatchThreshold, experimentalDiffStrategy)
|
||||
@@ -756,8 +763,8 @@ export class Cline {
|
||||
// grouping command_output messages despite any gaps anyways)
|
||||
await delay(50)
|
||||
|
||||
const { terminalOutputLineLimit } = await this.providerRef.deref()?.getState() ?? {}
|
||||
const output = truncateOutput(lines.join('\n'), terminalOutputLineLimit)
|
||||
const { terminalOutputLineLimit } = (await this.providerRef.deref()?.getState()) ?? {}
|
||||
const output = truncateOutput(lines.join("\n"), terminalOutputLineLimit)
|
||||
const result = output.trim()
|
||||
|
||||
if (userFeedback) {
|
||||
@@ -788,7 +795,8 @@ export class Cline {
|
||||
async *attemptApiRequest(previousApiReqIndex: number): ApiStream {
|
||||
let mcpHub: McpHub | undefined
|
||||
|
||||
const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } = await this.providerRef.deref()?.getState() ?? {}
|
||||
const { mcpEnabled, alwaysApproveResubmit, requestDelaySeconds } =
|
||||
(await this.providerRef.deref()?.getState()) ?? {}
|
||||
|
||||
if (mcpEnabled ?? true) {
|
||||
mcpHub = this.providerRef.deref()?.mcpHub
|
||||
@@ -801,24 +809,27 @@ export class Cline {
|
||||
})
|
||||
}
|
||||
|
||||
const { browserViewportSize, preferredLanguage, mode, customPrompts } = await this.providerRef.deref()?.getState() ?? {}
|
||||
const systemPrompt = await SYSTEM_PROMPT(
|
||||
const { browserViewportSize, preferredLanguage, mode, customPrompts } =
|
||||
(await this.providerRef.deref()?.getState()) ?? {}
|
||||
const systemPrompt =
|
||||
(await SYSTEM_PROMPT(
|
||||
cwd,
|
||||
this.api.getModel().info.supportsComputerUse ?? false,
|
||||
mcpHub,
|
||||
this.diffStrategy,
|
||||
browserViewportSize,
|
||||
mode,
|
||||
customPrompts
|
||||
) + await addCustomInstructions(
|
||||
customPrompts,
|
||||
)) +
|
||||
(await addCustomInstructions(
|
||||
{
|
||||
customInstructions: this.customInstructions,
|
||||
customPrompts,
|
||||
preferredLanguage
|
||||
preferredLanguage,
|
||||
},
|
||||
cwd,
|
||||
mode
|
||||
)
|
||||
mode,
|
||||
))
|
||||
|
||||
// If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request
|
||||
if (previousApiReqIndex >= 0) {
|
||||
@@ -845,18 +856,18 @@ export class Cline {
|
||||
if (Array.isArray(content)) {
|
||||
if (!this.api.getModel().info.supportsImages) {
|
||||
// Convert image blocks to text descriptions
|
||||
content = content.map(block => {
|
||||
if (block.type === 'image') {
|
||||
content = content.map((block) => {
|
||||
if (block.type === "image") {
|
||||
// Convert image blocks to text descriptions
|
||||
// Note: We can't access the actual image content/url due to API limitations,
|
||||
// but we can indicate that an image was present in the conversation
|
||||
return {
|
||||
type: 'text',
|
||||
text: '[Referenced image in conversation]'
|
||||
};
|
||||
type: "text",
|
||||
text: "[Referenced image in conversation]",
|
||||
}
|
||||
return block;
|
||||
});
|
||||
}
|
||||
return block
|
||||
})
|
||||
}
|
||||
}
|
||||
return { role, content }
|
||||
@@ -876,7 +887,12 @@ export class Cline {
|
||||
// Automatically retry with delay
|
||||
// Show countdown timer in error color
|
||||
for (let i = requestDelay; i > 0; i--) {
|
||||
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying in ${i} seconds...`, undefined, true)
|
||||
await this.say(
|
||||
"api_req_retry_delayed",
|
||||
`${errorMsg}\n\nRetrying in ${i} seconds...`,
|
||||
undefined,
|
||||
true,
|
||||
)
|
||||
await delay(1000)
|
||||
}
|
||||
await this.say("api_req_retry_delayed", `${errorMsg}\n\nRetrying now...`, undefined, false)
|
||||
@@ -1125,7 +1141,7 @@ export class Cline {
|
||||
}
|
||||
|
||||
// Validate tool use based on current mode
|
||||
const { mode } = await this.providerRef.deref()?.getState() ?? {}
|
||||
const { mode } = (await this.providerRef.deref()?.getState()) ?? {}
|
||||
try {
|
||||
validateToolUse(block.name, mode ?? defaultModeSlug)
|
||||
} catch (error) {
|
||||
@@ -1192,7 +1208,10 @@ export class Cline {
|
||||
await this.diffViewProvider.open(relPath)
|
||||
}
|
||||
// editor is open, stream content in
|
||||
await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, false)
|
||||
await this.diffViewProvider.update(
|
||||
everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent,
|
||||
false,
|
||||
)
|
||||
break
|
||||
} else {
|
||||
if (!relPath) {
|
||||
@@ -1209,7 +1228,9 @@ export class Cline {
|
||||
}
|
||||
if (!predictedLineCount) {
|
||||
this.consecutiveMistakeCount++
|
||||
pushToolResult(await this.sayAndCreateMissingParamError("write_to_file", "line_count"))
|
||||
pushToolResult(
|
||||
await this.sayAndCreateMissingParamError("write_to_file", "line_count"),
|
||||
)
|
||||
await this.diffViewProvider.reset()
|
||||
break
|
||||
}
|
||||
@@ -1224,17 +1245,28 @@ export class Cline {
|
||||
await this.ask("tool", partialMessage, true).catch(() => {}) // sending true for partial even though it's not a partial, this shows the edit row before the content is streamed into the editor
|
||||
await this.diffViewProvider.open(relPath)
|
||||
}
|
||||
await this.diffViewProvider.update(everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent, true)
|
||||
await this.diffViewProvider.update(
|
||||
everyLineHasLineNumbers(newContent) ? stripLineNumbers(newContent) : newContent,
|
||||
true,
|
||||
)
|
||||
await delay(300) // wait for diff view to update
|
||||
this.diffViewProvider.scrollToFirstDiff()
|
||||
|
||||
// Check for code omissions before proceeding
|
||||
if (detectCodeOmission(this.diffViewProvider.originalContent || "", newContent, predictedLineCount)) {
|
||||
if (
|
||||
detectCodeOmission(
|
||||
this.diffViewProvider.originalContent || "",
|
||||
newContent,
|
||||
predictedLineCount,
|
||||
)
|
||||
) {
|
||||
if (this.diffStrategy) {
|
||||
await this.diffViewProvider.revertChanges()
|
||||
pushToolResult(formatResponse.toolError(
|
||||
`Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.`
|
||||
))
|
||||
pushToolResult(
|
||||
formatResponse.toolError(
|
||||
`Content appears to be truncated (file has ${newContent.split("\n").length} lines but was predicted to have ${predictedLineCount} lines), and found comments indicating omitted code (e.g., '// rest of code unchanged', '/* previous code */'). Please provide the complete file content without any omissions if possible, or otherwise use the 'apply_diff' tool to apply the diff to the original file.`,
|
||||
),
|
||||
)
|
||||
break
|
||||
} else {
|
||||
vscode.window
|
||||
@@ -1285,7 +1317,7 @@ export class Cline {
|
||||
pushToolResult(
|
||||
`The user made the following updates to your content:\n\n${userEdits}\n\n` +
|
||||
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
|
||||
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || '')}\n</final_file_content>\n\n` +
|
||||
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || "")}\n</final_file_content>\n\n` +
|
||||
`Please note:\n` +
|
||||
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
|
||||
`2. Proceed with the task using this updated file content as the new baseline.\n` +
|
||||
@@ -1347,21 +1379,24 @@ export class Cline {
|
||||
const originalContent = await fs.readFile(absolutePath, "utf-8")
|
||||
|
||||
// Apply the diff to the original content
|
||||
const diffResult = await this.diffStrategy?.applyDiff(
|
||||
const diffResult = (await this.diffStrategy?.applyDiff(
|
||||
originalContent,
|
||||
diffContent,
|
||||
parseInt(block.params.start_line ?? ''),
|
||||
parseInt(block.params.end_line ?? '')
|
||||
) ?? {
|
||||
parseInt(block.params.start_line ?? ""),
|
||||
parseInt(block.params.end_line ?? ""),
|
||||
)) ?? {
|
||||
success: false,
|
||||
error: "No diff strategy available"
|
||||
error: "No diff strategy available",
|
||||
}
|
||||
if (!diffResult.success) {
|
||||
this.consecutiveMistakeCount++
|
||||
const currentCount = (this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1
|
||||
const currentCount =
|
||||
(this.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1
|
||||
this.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount)
|
||||
const errorDetails = diffResult.details ? JSON.stringify(diffResult.details, null, 2) : ''
|
||||
const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n<error_details>\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ''}\n</error_details>`
|
||||
const errorDetails = diffResult.details
|
||||
? JSON.stringify(diffResult.details, null, 2)
|
||||
: ""
|
||||
const formattedError = `Unable to apply diff to file: ${absolutePath}\n\n<error_details>\n${diffResult.error}${errorDetails ? `\n\nDetails:\n${errorDetails}` : ""}\n</error_details>`
|
||||
if (currentCount >= 2) {
|
||||
await this.say("error", formattedError)
|
||||
}
|
||||
@@ -1373,9 +1408,9 @@ export class Cline {
|
||||
this.consecutiveMistakeCountForApplyDiff.delete(relPath)
|
||||
// Show diff view before asking for approval
|
||||
this.diffViewProvider.editType = "modify"
|
||||
await this.diffViewProvider.open(relPath);
|
||||
await this.diffViewProvider.update(diffResult.content, true);
|
||||
await this.diffViewProvider.scrollToFirstDiff();
|
||||
await this.diffViewProvider.open(relPath)
|
||||
await this.diffViewProvider.update(diffResult.content, true)
|
||||
await this.diffViewProvider.scrollToFirstDiff()
|
||||
|
||||
const completeMessage = JSON.stringify({
|
||||
...sharedMessageProps,
|
||||
@@ -1403,7 +1438,7 @@ export class Cline {
|
||||
pushToolResult(
|
||||
`The user made the following updates to your content:\n\n${userEdits}\n\n` +
|
||||
`The updated content, which includes both your original modifications and the user's edits, has been successfully saved to ${relPath.toPosix()}. Here is the full, updated content of the file, including line numbers:\n\n` +
|
||||
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || '')}\n</final_file_content>\n\n` +
|
||||
`<final_file_content path="${relPath.toPosix()}">\n${addLineNumbers(finalContent || "")}\n</final_file_content>\n\n` +
|
||||
`Please note:\n` +
|
||||
`1. You do not need to re-write the file with these changes, as they have already been applied.\n` +
|
||||
`2. Proceed with the task using this updated file content as the new baseline.\n` +
|
||||
@@ -1411,7 +1446,9 @@ export class Cline {
|
||||
`${newProblemsMessage}`,
|
||||
)
|
||||
} else {
|
||||
pushToolResult(`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`)
|
||||
pushToolResult(
|
||||
`Changes successfully applied to ${relPath.toPosix()}:\n\n${newProblemsMessage}`,
|
||||
)
|
||||
}
|
||||
await this.diffViewProvider.reset()
|
||||
break
|
||||
@@ -1615,7 +1652,7 @@ export class Cline {
|
||||
await this.ask(
|
||||
"browser_action_launch",
|
||||
removeClosingTag("url", url),
|
||||
block.partial
|
||||
block.partial,
|
||||
).catch(() => {})
|
||||
} else {
|
||||
await this.say(
|
||||
@@ -1744,7 +1781,7 @@ export class Cline {
|
||||
try {
|
||||
if (block.partial) {
|
||||
await this.ask("command", removeClosingTag("command", command), block.partial).catch(
|
||||
() => {}
|
||||
() => {},
|
||||
)
|
||||
break
|
||||
} else {
|
||||
@@ -2409,7 +2446,7 @@ export class Cline {
|
||||
Promise.all(
|
||||
userContent.map(async (block) => {
|
||||
const shouldProcessMentions = (text: string) =>
|
||||
text.includes("<task>") || text.includes("<feedback>");
|
||||
text.includes("<task>") || text.includes("<feedback>")
|
||||
|
||||
if (block.type === "text") {
|
||||
if (shouldProcessMentions(block.text)) {
|
||||
@@ -2418,7 +2455,7 @@ export class Cline {
|
||||
text: await parseMentions(block.text, cwd, this.urlContentFetcher),
|
||||
}
|
||||
}
|
||||
return block;
|
||||
return block
|
||||
} else if (block.type === "tool_result") {
|
||||
if (typeof block.content === "string") {
|
||||
if (shouldProcessMentions(block.content)) {
|
||||
@@ -2427,7 +2464,7 @@ export class Cline {
|
||||
content: await parseMentions(block.content, cwd, this.urlContentFetcher),
|
||||
}
|
||||
}
|
||||
return block;
|
||||
return block
|
||||
} else if (Array.isArray(block.content)) {
|
||||
const parsedContent = await Promise.all(
|
||||
block.content.map(async (contentBlock) => {
|
||||
@@ -2445,7 +2482,7 @@ export class Cline {
|
||||
content: parsedContent,
|
||||
}
|
||||
}
|
||||
return block;
|
||||
return block
|
||||
}
|
||||
return block
|
||||
}),
|
||||
@@ -2571,26 +2608,29 @@ export class Cline {
|
||||
// Add current time information with timezone
|
||||
const now = new Date()
|
||||
const formatter = new Intl.DateTimeFormat(undefined, {
|
||||
year: 'numeric',
|
||||
month: 'numeric',
|
||||
day: 'numeric',
|
||||
hour: 'numeric',
|
||||
minute: 'numeric',
|
||||
second: 'numeric',
|
||||
hour12: true
|
||||
year: "numeric",
|
||||
month: "numeric",
|
||||
day: "numeric",
|
||||
hour: "numeric",
|
||||
minute: "numeric",
|
||||
second: "numeric",
|
||||
hour12: true,
|
||||
})
|
||||
const timeZone = formatter.resolvedOptions().timeZone
|
||||
const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation
|
||||
const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? '+' : ''}${timeZoneOffset}:00`
|
||||
const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? "+" : ""}${timeZoneOffset}:00`
|
||||
details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})`
|
||||
|
||||
// Add current mode and any mode-specific warnings
|
||||
const { mode } = await this.providerRef.deref()?.getState() ?? {}
|
||||
const { mode } = (await this.providerRef.deref()?.getState()) ?? {}
|
||||
const currentMode = mode ?? defaultModeSlug
|
||||
details += `\n\n# Current Mode\n${currentMode}`
|
||||
|
||||
// Add warning if not in code mode
|
||||
if (!isToolAllowedForMode('write_to_file', currentMode) || !isToolAllowedForMode('execute_command', currentMode)) {
|
||||
if (
|
||||
!isToolAllowedForMode("write_to_file", currentMode) ||
|
||||
!isToolAllowedForMode("execute_command", currentMode)
|
||||
) {
|
||||
details += `\n\nNOTE: You are currently in '${currentMode}' mode which only allows read-only operations. To write files or execute commands, the user will need to switch to '${defaultModeSlug}' mode. Note that only the user can switch modes.`
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,36 +1,36 @@
|
||||
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from '../../shared/modes';
|
||||
import { validateToolUse } from '../mode-validator';
|
||||
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig, modes } from "../../shared/modes"
|
||||
import { validateToolUse } from "../mode-validator"
|
||||
|
||||
const asTestTool = (tool: string): TestToolName => tool as TestToolName;
|
||||
const [codeMode, architectMode, askMode] = modes.map(mode => mode.slug);
|
||||
const asTestTool = (tool: string): TestToolName => tool as TestToolName
|
||||
const [codeMode, architectMode, askMode] = modes.map((mode) => mode.slug)
|
||||
|
||||
describe('mode-validator', () => {
|
||||
describe('isToolAllowedForMode', () => {
|
||||
describe('code mode', () => {
|
||||
it('allows all code mode tools', () => {
|
||||
const mode = getModeConfig(codeMode);
|
||||
describe("mode-validator", () => {
|
||||
describe("isToolAllowedForMode", () => {
|
||||
describe("code mode", () => {
|
||||
it("allows all code mode tools", () => {
|
||||
const mode = getModeConfig(codeMode)
|
||||
mode.tools.forEach(([tool]) => {
|
||||
expect(isToolAllowedForMode(tool, codeMode)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('disallows unknown tools', () => {
|
||||
expect(isToolAllowedForMode(asTestTool('unknown_tool'), codeMode)).toBe(false)
|
||||
it("disallows unknown tools", () => {
|
||||
expect(isToolAllowedForMode(asTestTool("unknown_tool"), codeMode)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('architect mode', () => {
|
||||
it('allows configured tools', () => {
|
||||
const mode = getModeConfig(architectMode);
|
||||
describe("architect mode", () => {
|
||||
it("allows configured tools", () => {
|
||||
const mode = getModeConfig(architectMode)
|
||||
mode.tools.forEach(([tool]) => {
|
||||
expect(isToolAllowedForMode(tool, architectMode)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ask mode', () => {
|
||||
it('allows configured tools', () => {
|
||||
const mode = getModeConfig(askMode);
|
||||
describe("ask mode", () => {
|
||||
it("allows configured tools", () => {
|
||||
const mode = getModeConfig(askMode)
|
||||
mode.tools.forEach(([tool]) => {
|
||||
expect(isToolAllowedForMode(tool, askMode)).toBe(true)
|
||||
})
|
||||
@@ -38,15 +38,15 @@ describe('mode-validator', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateToolUse', () => {
|
||||
it('throws error for disallowed tools in architect mode', () => {
|
||||
expect(() => validateToolUse('unknown_tool', 'architect')).toThrow(
|
||||
'Tool "unknown_tool" is not allowed in architect mode.'
|
||||
describe("validateToolUse", () => {
|
||||
it("throws error for disallowed tools in architect mode", () => {
|
||||
expect(() => validateToolUse("unknown_tool", "architect")).toThrow(
|
||||
'Tool "unknown_tool" is not allowed in architect mode.',
|
||||
)
|
||||
})
|
||||
|
||||
it('does not throw for allowed tools in architect mode', () => {
|
||||
expect(() => validateToolUse('read_file', 'architect')).not.toThrow()
|
||||
it("does not throw for allowed tools in architect mode", () => {
|
||||
expect(() => validateToolUse("read_file", "architect")).not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,7 @@
|
||||
import { ExtensionContext } from 'vscode'
|
||||
import { ApiConfiguration } from '../../shared/api'
|
||||
import { Mode } from '../prompts/types'
|
||||
import { ApiConfigMeta } from '../../shared/ExtensionMessage'
|
||||
import { ExtensionContext } from "vscode"
|
||||
import { ApiConfiguration } from "../../shared/api"
|
||||
import { Mode } from "../prompts/types"
|
||||
import { ApiConfigMeta } from "../../shared/ExtensionMessage"
|
||||
|
||||
export interface ApiConfigData {
|
||||
currentApiConfigName: string
|
||||
@@ -13,12 +13,12 @@ export interface ApiConfigData {
|
||||
|
||||
export class ConfigManager {
|
||||
private readonly defaultConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
id: this.generateId()
|
||||
}
|
||||
}
|
||||
id: this.generateId(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
private readonly SCOPE_PREFIX = "roo_cline_config_"
|
||||
@@ -69,7 +69,7 @@ export class ConfigManager {
|
||||
const config = await this.readConfig()
|
||||
return Object.entries(config.apiConfigs).map(([name, apiConfig]) => ({
|
||||
name,
|
||||
id: apiConfig.id || '',
|
||||
id: apiConfig.id || "",
|
||||
apiProvider: apiConfig.apiProvider,
|
||||
}))
|
||||
} catch (error) {
|
||||
@@ -86,7 +86,7 @@ export class ConfigManager {
|
||||
const existingConfig = currentConfig.apiConfigs[name]
|
||||
currentConfig.apiConfigs[name] = {
|
||||
...config,
|
||||
id: existingConfig?.id || this.generateId()
|
||||
id: existingConfig?.id || this.generateId(),
|
||||
}
|
||||
await this.writeConfig(currentConfig)
|
||||
} catch (error) {
|
||||
@@ -106,7 +106,7 @@ export class ConfigManager {
|
||||
throw new Error(`Config '${name}' not found`)
|
||||
}
|
||||
|
||||
config.currentApiConfigName = name;
|
||||
config.currentApiConfigName = name
|
||||
await this.writeConfig(config)
|
||||
|
||||
return apiConfig
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import { ExtensionContext } from 'vscode'
|
||||
import { ConfigManager, ApiConfigData } from '../ConfigManager'
|
||||
import { ApiConfiguration } from '../../../shared/api'
|
||||
import { ExtensionContext } from "vscode"
|
||||
import { ConfigManager, ApiConfigData } from "../ConfigManager"
|
||||
import { ApiConfiguration } from "../../../shared/api"
|
||||
|
||||
// Mock VSCode ExtensionContext
|
||||
const mockSecrets = {
|
||||
get: jest.fn(),
|
||||
store: jest.fn(),
|
||||
delete: jest.fn()
|
||||
delete: jest.fn(),
|
||||
}
|
||||
|
||||
const mockContext = {
|
||||
secrets: mockSecrets
|
||||
secrets: mockSecrets,
|
||||
} as unknown as ExtensionContext
|
||||
|
||||
describe('ConfigManager', () => {
|
||||
describe("ConfigManager", () => {
|
||||
let configManager: ConfigManager
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -21,8 +21,8 @@ describe('ConfigManager', () => {
|
||||
configManager = new ConfigManager(mockContext)
|
||||
})
|
||||
|
||||
describe('initConfig', () => {
|
||||
it('should not write to storage when secrets.get returns null', async () => {
|
||||
describe("initConfig", () => {
|
||||
it("should not write to storage when secrets.get returns null", async () => {
|
||||
// Mock readConfig to return null
|
||||
mockSecrets.get.mockResolvedValueOnce(null)
|
||||
|
||||
@@ -32,35 +32,39 @@ describe('ConfigManager', () => {
|
||||
expect(mockSecrets.store).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not initialize config if it exists', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
it("should not initialize config if it exists", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
config: {},
|
||||
id: 'default'
|
||||
}
|
||||
}
|
||||
}))
|
||||
id: "default",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
await configManager.initConfig()
|
||||
|
||||
expect(mockSecrets.store).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should generate IDs for configs that lack them', async () => {
|
||||
it("should generate IDs for configs that lack them", async () => {
|
||||
// Mock a config with missing IDs
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
config: {}
|
||||
config: {},
|
||||
},
|
||||
test: {
|
||||
apiProvider: 'anthropic'
|
||||
}
|
||||
}
|
||||
}))
|
||||
apiProvider: "anthropic",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
await configManager.initConfig()
|
||||
|
||||
@@ -71,53 +75,53 @@ describe('ConfigManager', () => {
|
||||
expect(storedConfig.apiConfigs.test.id).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should throw error if secrets storage fails', async () => {
|
||||
mockSecrets.get.mockRejectedValue(new Error('Storage failed'))
|
||||
it("should throw error if secrets storage fails", async () => {
|
||||
mockSecrets.get.mockRejectedValue(new Error("Storage failed"))
|
||||
|
||||
await expect(configManager.initConfig()).rejects.toThrow(
|
||||
'Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed'
|
||||
"Failed to initialize config: Error: Failed to read config from secrets: Error: Storage failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ListConfig', () => {
|
||||
it('should list all available configs', async () => {
|
||||
describe("ListConfig", () => {
|
||||
it("should list all available configs", async () => {
|
||||
const existingConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
id: 'default'
|
||||
id: "default",
|
||||
},
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
id: 'test-id'
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
modeApiConfigs: {
|
||||
code: 'default',
|
||||
architect: 'default',
|
||||
ask: 'default'
|
||||
}
|
||||
code: "default",
|
||||
architect: "default",
|
||||
ask: "default",
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||
|
||||
const configs = await configManager.ListConfig()
|
||||
expect(configs).toEqual([
|
||||
{ name: 'default', id: 'default', apiProvider: undefined },
|
||||
{ name: 'test', id: 'test-id', apiProvider: 'anthropic' }
|
||||
{ name: "default", id: "default", apiProvider: undefined },
|
||||
{ name: "test", id: "test-id", apiProvider: "anthropic" },
|
||||
])
|
||||
})
|
||||
|
||||
it('should handle empty config file', async () => {
|
||||
it("should handle empty config file", async () => {
|
||||
const emptyConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {},
|
||||
modeApiConfigs: {
|
||||
code: 'default',
|
||||
architect: 'default',
|
||||
ask: 'default'
|
||||
}
|
||||
code: "default",
|
||||
architect: "default",
|
||||
ask: "default",
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(emptyConfig))
|
||||
@@ -126,326 +130,340 @@ describe('ConfigManager', () => {
|
||||
expect(configs).toEqual([])
|
||||
})
|
||||
|
||||
it('should throw error if reading from secrets fails', async () => {
|
||||
mockSecrets.get.mockRejectedValue(new Error('Read failed'))
|
||||
it("should throw error if reading from secrets fails", async () => {
|
||||
mockSecrets.get.mockRejectedValue(new Error("Read failed"))
|
||||
|
||||
await expect(configManager.ListConfig()).rejects.toThrow(
|
||||
'Failed to list configs: Error: Failed to read config from secrets: Error: Read failed'
|
||||
"Failed to list configs: Error: Failed to read config from secrets: Error: Read failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('SaveConfig', () => {
|
||||
it('should save new config', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
describe("SaveConfig", () => {
|
||||
it("should save new config", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {}
|
||||
default: {},
|
||||
},
|
||||
modeApiConfigs: {
|
||||
code: 'default',
|
||||
architect: 'default',
|
||||
ask: 'default'
|
||||
}
|
||||
}))
|
||||
code: "default",
|
||||
architect: "default",
|
||||
ask: "default",
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
const newConfig: ApiConfiguration = {
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'test-key'
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "test-key",
|
||||
}
|
||||
|
||||
await configManager.SaveConfig('test', newConfig)
|
||||
await configManager.SaveConfig("test", newConfig)
|
||||
|
||||
// Get the actual stored config to check the generated ID
|
||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||
const testConfigId = storedConfig.apiConfigs.test.id
|
||||
|
||||
const expectedConfig = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {},
|
||||
test: {
|
||||
...newConfig,
|
||||
id: testConfigId
|
||||
}
|
||||
id: testConfigId,
|
||||
},
|
||||
},
|
||||
modeApiConfigs: {
|
||||
code: 'default',
|
||||
architect: 'default',
|
||||
ask: 'default'
|
||||
}
|
||||
code: "default",
|
||||
architect: "default",
|
||||
ask: "default",
|
||||
},
|
||||
}
|
||||
|
||||
expect(mockSecrets.store).toHaveBeenCalledWith(
|
||||
'roo_cline_config_api_config',
|
||||
JSON.stringify(expectedConfig, null, 2)
|
||||
"roo_cline_config_api_config",
|
||||
JSON.stringify(expectedConfig, null, 2),
|
||||
)
|
||||
})
|
||||
|
||||
it('should update existing config', async () => {
|
||||
it("should update existing config", async () => {
|
||||
const existingConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'old-key',
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "old-key",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||
|
||||
const updatedConfig: ApiConfiguration = {
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'new-key'
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "new-key",
|
||||
}
|
||||
|
||||
await configManager.SaveConfig('test', updatedConfig)
|
||||
await configManager.SaveConfig("test", updatedConfig)
|
||||
|
||||
const expectedConfig = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'new-key',
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "new-key",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expect(mockSecrets.store).toHaveBeenCalledWith(
|
||||
'roo_cline_config_api_config',
|
||||
JSON.stringify(expectedConfig, null, 2)
|
||||
"roo_cline_config_api_config",
|
||||
JSON.stringify(expectedConfig, null, 2),
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error if secrets storage fails', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
apiConfigs: { default: {} }
|
||||
}))
|
||||
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed'))
|
||||
it("should throw error if secrets storage fails", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: { default: {} },
|
||||
}),
|
||||
)
|
||||
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
|
||||
|
||||
await expect(configManager.SaveConfig('test', {})).rejects.toThrow(
|
||||
'Failed to save config: Error: Failed to write config to secrets: Error: Storage failed'
|
||||
await expect(configManager.SaveConfig("test", {})).rejects.toThrow(
|
||||
"Failed to save config: Error: Failed to write config to secrets: Error: Storage failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('DeleteConfig', () => {
|
||||
it('should delete existing config', async () => {
|
||||
describe("DeleteConfig", () => {
|
||||
it("should delete existing config", async () => {
|
||||
const existingConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
id: 'default'
|
||||
id: "default",
|
||||
},
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||
|
||||
await configManager.DeleteConfig('test')
|
||||
await configManager.DeleteConfig("test")
|
||||
|
||||
// Get the stored config to check the ID
|
||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||
expect(storedConfig.currentApiConfigName).toBe('default')
|
||||
expect(Object.keys(storedConfig.apiConfigs)).toEqual(['default'])
|
||||
expect(storedConfig.currentApiConfigName).toBe("default")
|
||||
expect(Object.keys(storedConfig.apiConfigs)).toEqual(["default"])
|
||||
expect(storedConfig.apiConfigs.default.id).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should throw error when trying to delete non-existent config', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
apiConfigs: { default: {} }
|
||||
}))
|
||||
|
||||
await expect(configManager.DeleteConfig('nonexistent')).rejects.toThrow(
|
||||
"Config 'nonexistent' not found"
|
||||
it("should throw error when trying to delete non-existent config", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: { default: {} },
|
||||
}),
|
||||
)
|
||||
|
||||
await expect(configManager.DeleteConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found")
|
||||
})
|
||||
|
||||
it('should throw error when trying to delete last remaining config', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
it("should throw error when trying to delete last remaining config", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
id: 'default'
|
||||
}
|
||||
}
|
||||
}))
|
||||
id: "default",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
await expect(configManager.DeleteConfig('default')).rejects.toThrow(
|
||||
'Cannot delete the last remaining configuration.'
|
||||
await expect(configManager.DeleteConfig("default")).rejects.toThrow(
|
||||
"Cannot delete the last remaining configuration.",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('LoadConfig', () => {
|
||||
it('should load config and update current config name', async () => {
|
||||
describe("LoadConfig", () => {
|
||||
it("should load config and update current config name", async () => {
|
||||
const existingConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'test-key',
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "test-key",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||
|
||||
const config = await configManager.LoadConfig('test')
|
||||
const config = await configManager.LoadConfig("test")
|
||||
|
||||
expect(config).toEqual({
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'test-key',
|
||||
id: 'test-id'
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "test-key",
|
||||
id: "test-id",
|
||||
})
|
||||
|
||||
// Get the stored config to check the structure
|
||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||
expect(storedConfig.currentApiConfigName).toBe('test')
|
||||
expect(storedConfig.currentApiConfigName).toBe("test")
|
||||
expect(storedConfig.apiConfigs.test).toEqual({
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'test-key',
|
||||
id: 'test-id'
|
||||
apiProvider: "anthropic",
|
||||
apiKey: "test-key",
|
||||
id: "test-id",
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error when config does not exist', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
it("should throw error when config does not exist", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
config: {},
|
||||
id: 'default'
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
await expect(configManager.LoadConfig('nonexistent')).rejects.toThrow(
|
||||
"Config 'nonexistent' not found"
|
||||
id: "default",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
await expect(configManager.LoadConfig("nonexistent")).rejects.toThrow("Config 'nonexistent' not found")
|
||||
})
|
||||
|
||||
it('should throw error if secrets storage fails', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
it("should throw error if secrets storage fails", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
test: {
|
||||
config: {
|
||||
apiProvider: 'anthropic'
|
||||
apiProvider: "anthropic",
|
||||
},
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
}))
|
||||
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed'))
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
|
||||
|
||||
await expect(configManager.LoadConfig('test')).rejects.toThrow(
|
||||
'Failed to load config: Error: Failed to write config to secrets: Error: Storage failed'
|
||||
await expect(configManager.LoadConfig("test")).rejects.toThrow(
|
||||
"Failed to load config: Error: Failed to write config to secrets: Error: Storage failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('SetCurrentConfig', () => {
|
||||
it('should set current config', async () => {
|
||||
describe("SetCurrentConfig", () => {
|
||||
it("should set current config", async () => {
|
||||
const existingConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
id: 'default'
|
||||
id: "default",
|
||||
},
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||
|
||||
await configManager.SetCurrentConfig('test')
|
||||
await configManager.SetCurrentConfig("test")
|
||||
|
||||
// Get the stored config to check the structure
|
||||
const storedConfig = JSON.parse(mockSecrets.store.mock.calls[0][1])
|
||||
expect(storedConfig.currentApiConfigName).toBe('test')
|
||||
expect(storedConfig.apiConfigs.default.id).toBe('default')
|
||||
expect(storedConfig.currentApiConfigName).toBe("test")
|
||||
expect(storedConfig.apiConfigs.default.id).toBe("default")
|
||||
expect(storedConfig.apiConfigs.test).toEqual({
|
||||
apiProvider: 'anthropic',
|
||||
id: 'test-id'
|
||||
apiProvider: "anthropic",
|
||||
id: "test-id",
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error when config does not exist', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
apiConfigs: { default: {} }
|
||||
}))
|
||||
it("should throw error when config does not exist", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: { default: {} },
|
||||
}),
|
||||
)
|
||||
|
||||
await expect(configManager.SetCurrentConfig('nonexistent')).rejects.toThrow(
|
||||
"Config 'nonexistent' not found"
|
||||
await expect(configManager.SetCurrentConfig("nonexistent")).rejects.toThrow(
|
||||
"Config 'nonexistent' not found",
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error if secrets storage fails', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
it("should throw error if secrets storage fails", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
test: { apiProvider: 'anthropic' }
|
||||
}
|
||||
}))
|
||||
mockSecrets.store.mockRejectedValueOnce(new Error('Storage failed'))
|
||||
test: { apiProvider: "anthropic" },
|
||||
},
|
||||
}),
|
||||
)
|
||||
mockSecrets.store.mockRejectedValueOnce(new Error("Storage failed"))
|
||||
|
||||
await expect(configManager.SetCurrentConfig('test')).rejects.toThrow(
|
||||
'Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed'
|
||||
await expect(configManager.SetCurrentConfig("test")).rejects.toThrow(
|
||||
"Failed to set current config: Error: Failed to write config to secrets: Error: Storage failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('HasConfig', () => {
|
||||
it('should return true for existing config', async () => {
|
||||
describe("HasConfig", () => {
|
||||
it("should return true for existing config", async () => {
|
||||
const existingConfig: ApiConfigData = {
|
||||
currentApiConfigName: 'default',
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: {
|
||||
default: {
|
||||
id: 'default'
|
||||
id: "default",
|
||||
},
|
||||
test: {
|
||||
apiProvider: 'anthropic',
|
||||
id: 'test-id'
|
||||
}
|
||||
}
|
||||
apiProvider: "anthropic",
|
||||
id: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify(existingConfig))
|
||||
|
||||
const hasConfig = await configManager.HasConfig('test')
|
||||
const hasConfig = await configManager.HasConfig("test")
|
||||
expect(hasConfig).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for non-existent config', async () => {
|
||||
mockSecrets.get.mockResolvedValue(JSON.stringify({
|
||||
currentApiConfigName: 'default',
|
||||
apiConfigs: { default: {} }
|
||||
}))
|
||||
it("should return false for non-existent config", async () => {
|
||||
mockSecrets.get.mockResolvedValue(
|
||||
JSON.stringify({
|
||||
currentApiConfigName: "default",
|
||||
apiConfigs: { default: {} },
|
||||
}),
|
||||
)
|
||||
|
||||
const hasConfig = await configManager.HasConfig('nonexistent')
|
||||
const hasConfig = await configManager.HasConfig("nonexistent")
|
||||
expect(hasConfig).toBe(false)
|
||||
})
|
||||
|
||||
it('should throw error if secrets storage fails', async () => {
|
||||
mockSecrets.get.mockRejectedValue(new Error('Storage failed'))
|
||||
it("should throw error if secrets storage fails", async () => {
|
||||
mockSecrets.get.mockRejectedValue(new Error("Storage failed"))
|
||||
|
||||
await expect(configManager.HasConfig('test')).rejects.toThrow(
|
||||
'Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed'
|
||||
await expect(configManager.HasConfig("test")).rejects.toThrow(
|
||||
"Failed to check config existence: Error: Failed to read config from secrets: Error: Storage failed",
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import type { DiffStrategy } from './types'
|
||||
import { UnifiedDiffStrategy } from './strategies/unified'
|
||||
import { SearchReplaceDiffStrategy } from './strategies/search-replace'
|
||||
import { NewUnifiedDiffStrategy } from './strategies/new-unified'
|
||||
import type { DiffStrategy } from "./types"
|
||||
import { UnifiedDiffStrategy } from "./strategies/unified"
|
||||
import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
|
||||
import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
|
||||
/**
|
||||
* Get the appropriate diff strategy for the given model
|
||||
* @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus')
|
||||
* @returns The appropriate diff strategy for the model
|
||||
*/
|
||||
export function getDiffStrategy(model: string, fuzzyMatchThreshold?: number, experimentalDiffStrategy: boolean = false): DiffStrategy {
|
||||
export function getDiffStrategy(
|
||||
model: string,
|
||||
fuzzyMatchThreshold?: number,
|
||||
experimentalDiffStrategy: boolean = false,
|
||||
): DiffStrategy {
|
||||
if (experimentalDiffStrategy) {
|
||||
return new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
|
||||
}
|
||||
|
||||
@@ -1,46 +1,45 @@
|
||||
import { NewUnifiedDiffStrategy } from '../new-unified';
|
||||
|
||||
describe('main', () => {
|
||||
import { NewUnifiedDiffStrategy } from "../new-unified"
|
||||
|
||||
describe("main", () => {
|
||||
let strategy: NewUnifiedDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new NewUnifiedDiffStrategy(0.97)
|
||||
})
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should use default confidence threshold when not provided', () => {
|
||||
describe("constructor", () => {
|
||||
it("should use default confidence threshold when not provided", () => {
|
||||
const defaultStrategy = new NewUnifiedDiffStrategy()
|
||||
expect(defaultStrategy['confidenceThreshold']).toBe(1)
|
||||
expect(defaultStrategy["confidenceThreshold"]).toBe(1)
|
||||
})
|
||||
|
||||
it('should use provided confidence threshold', () => {
|
||||
it("should use provided confidence threshold", () => {
|
||||
const customStrategy = new NewUnifiedDiffStrategy(0.85)
|
||||
expect(customStrategy['confidenceThreshold']).toBe(0.85)
|
||||
expect(customStrategy["confidenceThreshold"]).toBe(0.85)
|
||||
})
|
||||
|
||||
it('should enforce minimum confidence threshold', () => {
|
||||
it("should enforce minimum confidence threshold", () => {
|
||||
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
|
||||
expect(lowStrategy['confidenceThreshold']).toBe(0.8)
|
||||
expect(lowStrategy["confidenceThreshold"]).toBe(0.8)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getToolDescription', () => {
|
||||
it('should return tool description with correct cwd', () => {
|
||||
const cwd = '/test/path'
|
||||
describe("getToolDescription", () => {
|
||||
it("should return tool description with correct cwd", () => {
|
||||
const cwd = "/test/path"
|
||||
const description = strategy.getToolDescription({ cwd })
|
||||
|
||||
expect(description).toContain('apply_diff')
|
||||
expect(description).toContain("apply_diff")
|
||||
expect(description).toContain(cwd)
|
||||
expect(description).toContain('Parameters:')
|
||||
expect(description).toContain('Format Requirements:')
|
||||
expect(description).toContain("Parameters:")
|
||||
expect(description).toContain("Format Requirements:")
|
||||
})
|
||||
})
|
||||
|
||||
it('should apply simple diff correctly', async () => {
|
||||
it("should apply simple diff correctly", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3`;
|
||||
line3`
|
||||
|
||||
const diff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -49,24 +48,24 @@ line3`;
|
||||
+new line
|
||||
line2
|
||||
-line3
|
||||
+modified line3`;
|
||||
+modified line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
new line
|
||||
line2
|
||||
modified line3`);
|
||||
modified line3`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle multiple hunks', async () => {
|
||||
it("should handle multiple hunks", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3
|
||||
line4
|
||||
line5`;
|
||||
line5`
|
||||
|
||||
const diff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -80,10 +79,10 @@ line5`;
|
||||
line4
|
||||
-line5
|
||||
+modified line5
|
||||
+new line at end`;
|
||||
+new line at end`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
new line
|
||||
@@ -91,11 +90,11 @@ line2
|
||||
modified line3
|
||||
line4
|
||||
modified line5
|
||||
new line at end`);
|
||||
new line at end`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle complex large', async () => {
|
||||
it("should handle complex large", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3
|
||||
@@ -105,7 +104,7 @@ line6
|
||||
line7
|
||||
line8
|
||||
line9
|
||||
line10`;
|
||||
line10`
|
||||
|
||||
const diff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -130,10 +129,10 @@ line10`;
|
||||
line9
|
||||
-line10
|
||||
+final line
|
||||
+very last line`;
|
||||
+very last line`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
header line
|
||||
@@ -150,11 +149,11 @@ changed line8
|
||||
bonus line
|
||||
line9
|
||||
final line
|
||||
very last line`);
|
||||
very last line`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle indentation changes', async () => {
|
||||
it("should handle indentation changes", async () => {
|
||||
const original = `first line
|
||||
indented line
|
||||
double indented line
|
||||
@@ -164,7 +163,7 @@ no indent
|
||||
double indent again
|
||||
triple indent
|
||||
back to single
|
||||
last line`;
|
||||
last line`
|
||||
|
||||
const diff = `--- original
|
||||
+++ modified
|
||||
@@ -181,7 +180,7 @@ last line`;
|
||||
- triple indent
|
||||
+ hi there mate
|
||||
back to single
|
||||
last line`;
|
||||
last line`
|
||||
|
||||
const expected = `first line
|
||||
indented line
|
||||
@@ -194,17 +193,16 @@ no indent
|
||||
double indent again
|
||||
hi there mate
|
||||
back to single
|
||||
last line`;
|
||||
last line`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected);
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle high level edits', async () => {
|
||||
})
|
||||
|
||||
it("should handle high level edits", async () => {
|
||||
const original = `def factorial(n):
|
||||
if n == 0:
|
||||
return 1
|
||||
@@ -228,14 +226,14 @@ const expected = `def factorial(number):
|
||||
else:
|
||||
return number * factorial(number-1)`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected);
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('it should handle very complex edits', async () => {
|
||||
it("it should handle very complex edits", async () => {
|
||||
const original = `//Initialize the array that will hold the primes
|
||||
var primeArray = [];
|
||||
/*Write a function that checks for primeness and
|
||||
@@ -321,56 +319,55 @@ for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||
console.log(primeArray);
|
||||
`
|
||||
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected);
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
describe('error handling and edge cases', () => {
|
||||
it('should reject completely invalid diff format', async () => {
|
||||
const original = 'line1\nline2\nline3';
|
||||
const invalidDiff = 'this is not a diff at all';
|
||||
describe("error handling and edge cases", () => {
|
||||
it("should reject completely invalid diff format", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
const invalidDiff = "this is not a diff at all"
|
||||
|
||||
const result = await strategy.applyDiff(original, invalidDiff);
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
const result = await strategy.applyDiff(original, invalidDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should reject diff with invalid hunk format', async () => {
|
||||
const original = 'line1\nline2\nline3';
|
||||
it("should reject diff with invalid hunk format", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
const invalidHunkDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
invalid hunk header
|
||||
line1
|
||||
-line2
|
||||
+new line`;
|
||||
+new line`
|
||||
|
||||
const result = await strategy.applyDiff(original, invalidHunkDiff);
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
const result = await strategy.applyDiff(original, invalidHunkDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should fail when diff tries to modify non-existent content', async () => {
|
||||
const original = 'line1\nline2\nline3';
|
||||
it("should fail when diff tries to modify non-existent content", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
const nonMatchingDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
-nonexistent line
|
||||
+new line
|
||||
line3`;
|
||||
line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, nonMatchingDiff);
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
const result = await strategy.applyDiff(original, nonMatchingDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle overlapping hunks', async () => {
|
||||
it("should handle overlapping hunks", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3
|
||||
line4
|
||||
line5`;
|
||||
line5`
|
||||
const overlappingDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
@@ -384,18 +381,18 @@ line5`;
|
||||
-line3
|
||||
-line4
|
||||
+modified3and4
|
||||
line5`;
|
||||
line5`
|
||||
|
||||
const result = await strategy.applyDiff(original, overlappingDiff);
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
const result = await strategy.applyDiff(original, overlappingDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle empty lines modifications', async () => {
|
||||
it("should handle empty lines modifications", async () => {
|
||||
const original = `line1
|
||||
|
||||
line3
|
||||
|
||||
line5`;
|
||||
line5`
|
||||
const emptyLinesDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
@@ -404,53 +401,53 @@ line5`;
|
||||
-line3
|
||||
+line3modified
|
||||
|
||||
line5`;
|
||||
line5`
|
||||
|
||||
const result = await strategy.applyDiff(original, emptyLinesDiff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, emptyLinesDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
|
||||
line3modified
|
||||
|
||||
line5`);
|
||||
line5`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle mixed line endings in diff', async () => {
|
||||
const original = 'line1\r\nline2\nline3\r\n';
|
||||
it("should handle mixed line endings in diff", async () => {
|
||||
const original = "line1\r\nline2\nline3\r\n"
|
||||
const mixedEndingsDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1\r
|
||||
-line2
|
||||
+modified2\r
|
||||
line3`;
|
||||
line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, mixedEndingsDiff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, mixedEndingsDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('line1\r\nmodified2\r\nline3\r\n');
|
||||
expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle partial line modifications', async () => {
|
||||
const original = 'const value = oldValue + 123;';
|
||||
it("should handle partial line modifications", async () => {
|
||||
const original = "const value = oldValue + 123;"
|
||||
const partialDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
-const value = oldValue + 123;
|
||||
+const value = newValue + 123;`;
|
||||
+const value = newValue + 123;`
|
||||
|
||||
const result = await strategy.applyDiff(original, partialDiff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, partialDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('const value = newValue + 123;');
|
||||
expect(result.content).toBe("const value = newValue + 123;")
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle slightly malformed but recoverable diff', async () => {
|
||||
const original = 'line1\nline2\nline3';
|
||||
it("should handle slightly malformed but recoverable diff", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
// Missing space after --- and +++
|
||||
const slightlyBadDiff = `---a/file.txt
|
||||
+++b/file.txt
|
||||
@@ -458,18 +455,18 @@ line5`);
|
||||
line1
|
||||
-line2
|
||||
+new line
|
||||
line3`;
|
||||
line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, slightlyBadDiff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, slightlyBadDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('line1\nnew line\nline3');
|
||||
expect(result.content).toBe("line1\nnew line\nline3")
|
||||
}
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
describe('similar code sections', () => {
|
||||
it('should correctly modify the right section when similar code exists', async () => {
|
||||
describe("similar code sections", () => {
|
||||
it("should correctly modify the right section when similar code exists", async () => {
|
||||
const original = `function add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
@@ -480,7 +477,7 @@ function subtract(a, b) {
|
||||
|
||||
function multiply(a, b) {
|
||||
return a + b; // Bug here
|
||||
}`;
|
||||
}`
|
||||
|
||||
const diff = `--- a/math.js
|
||||
+++ b/math.js
|
||||
@@ -488,10 +485,10 @@ function multiply(a, b) {
|
||||
function multiply(a, b) {
|
||||
- return a + b; // Bug here
|
||||
+ return a * b;
|
||||
}`;
|
||||
}`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`function add(a, b) {
|
||||
return a + b;
|
||||
@@ -503,11 +500,11 @@ function subtract(a, b) {
|
||||
|
||||
function multiply(a, b) {
|
||||
return a * b;
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle multiple similar sections with correct context', async () => {
|
||||
it("should handle multiple similar sections with correct context", async () => {
|
||||
const original = `if (condition) {
|
||||
doSomething();
|
||||
doSomething();
|
||||
@@ -518,7 +515,7 @@ if (otherCondition) {
|
||||
doSomething();
|
||||
doSomething();
|
||||
doSomething();
|
||||
}`;
|
||||
}`
|
||||
|
||||
const diff = `--- a/file.js
|
||||
+++ b/file.js
|
||||
@@ -528,10 +525,10 @@ if (otherCondition) {
|
||||
- doSomething();
|
||||
+ doSomethingElse();
|
||||
doSomething();
|
||||
}`;
|
||||
}`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`if (condition) {
|
||||
doSomething();
|
||||
@@ -543,13 +540,13 @@ if (otherCondition) {
|
||||
doSomething();
|
||||
doSomethingElse();
|
||||
doSomething();
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
describe('hunk splitting', () => {
|
||||
it('should handle large diffs with multiple non-contiguous changes', async () => {
|
||||
describe("hunk splitting", () => {
|
||||
it("should handle large diffs with multiple non-contiguous changes", async () => {
|
||||
const original = `import { readFile } from 'fs';
|
||||
import { join } from 'path';
|
||||
import { Logger } from './logger';
|
||||
@@ -595,7 +592,7 @@ export {
|
||||
validateInput,
|
||||
writeOutput,
|
||||
parseConfig
|
||||
};`;
|
||||
};`
|
||||
|
||||
const diff = `--- a/file.ts
|
||||
+++ b/file.ts
|
||||
@@ -672,7 +669,7 @@ export {
|
||||
- parseConfig
|
||||
+ parseConfig,
|
||||
+ type Config
|
||||
};`;
|
||||
};`
|
||||
|
||||
const expected = `import { readFile, writeFile } from 'fs';
|
||||
import { join } from 'path';
|
||||
@@ -727,13 +724,13 @@ export {
|
||||
writeOutput,
|
||||
parseConfig,
|
||||
type Config
|
||||
};`;
|
||||
};`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff);
|
||||
expect(result.success).toBe(true);
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected);
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { SearchReplaceDiffStrategy } from '../search-replace'
|
||||
import { SearchReplaceDiffStrategy } from "../search-replace"
|
||||
|
||||
describe('SearchReplaceDiffStrategy', () => {
|
||||
describe('exact matching', () => {
|
||||
describe("SearchReplaceDiffStrategy", () => {
|
||||
describe("exact matching", () => {
|
||||
let strategy: SearchReplaceDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new SearchReplaceDiffStrategy(1.0, 5) // Default 1.0 threshold for exact matching, 5 line buffer for tests
|
||||
})
|
||||
|
||||
it('should replace matching content', async () => {
|
||||
it("should replace matching content", async () => {
|
||||
const originalContent = 'function hello() {\n console.log("hello")\n}\n'
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -28,8 +28,8 @@ function hello() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should match content with different surrounding whitespace', async () => {
|
||||
const originalContent = '\nfunction example() {\n return 42;\n}\n\n'
|
||||
it("should match content with different surrounding whitespace", async () => {
|
||||
const originalContent = "\nfunction example() {\n return 42;\n}\n\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
function example() {
|
||||
@@ -44,12 +44,12 @@ function example() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('\nfunction example() {\n return 43;\n}\n\n')
|
||||
expect(result.content).toBe("\nfunction example() {\n return 43;\n}\n\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should match content with different indentation in search block', async () => {
|
||||
const originalContent = ' function test() {\n return true;\n }\n'
|
||||
it("should match content with different indentation in search block", async () => {
|
||||
const originalContent = " function test() {\n return true;\n }\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
function test() {
|
||||
@@ -64,11 +64,11 @@ function test() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(' function test() {\n return false;\n }\n')
|
||||
expect(result.content).toBe(" function test() {\n return false;\n }\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle tab-based indentation', async () => {
|
||||
it("should handle tab-based indentation", async () => {
|
||||
const originalContent = "function test() {\n\treturn true;\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -88,7 +88,7 @@ function test() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve mixed tabs and spaces', async () => {
|
||||
it("should preserve mixed tabs and spaces", async () => {
|
||||
const originalContent = "\tclass Example {\n\t constructor() {\n\t\tthis.value = 0;\n\t }\n\t}"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -108,11 +108,13 @@ function test() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe("\tclass Example {\n\t constructor() {\n\t\tthis.value = 1;\n\t }\n\t}")
|
||||
expect(result.content).toBe(
|
||||
"\tclass Example {\n\t constructor() {\n\t\tthis.value = 1;\n\t }\n\t}",
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle additional indentation with tabs', async () => {
|
||||
it("should handle additional indentation with tabs", async () => {
|
||||
const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -133,7 +135,7 @@ function test() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve exact indentation characters when adding lines', async () => {
|
||||
it("should preserve exact indentation characters when adding lines", async () => {
|
||||
const originalContent = "\tfunction test() {\n\t\treturn true;\n\t}"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -151,11 +153,13 @@ function test() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe("\tfunction test() {\n\t\t// First comment\n\t\t// Second comment\n\t\treturn true;\n\t}")
|
||||
expect(result.content).toBe(
|
||||
"\tfunction test() {\n\t\t// First comment\n\t\t// Second comment\n\t\treturn true;\n\t}",
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle Windows-style CRLF line endings', async () => {
|
||||
it("should handle Windows-style CRLF line endings", async () => {
|
||||
const originalContent = "function test() {\r\n return true;\r\n}\r\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -175,7 +179,7 @@ function test() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should return false if search content does not match', async () => {
|
||||
it("should return false if search content does not match", async () => {
|
||||
const originalContent = 'function hello() {\n console.log("hello")\n}\n'
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -192,7 +196,7 @@ function hello() {
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false if diff format is invalid', async () => {
|
||||
it("should return false if diff format is invalid", async () => {
|
||||
const originalContent = 'function hello() {\n console.log("hello")\n}\n'
|
||||
const diffContent = `test.ts\nInvalid diff format`
|
||||
|
||||
@@ -200,8 +204,9 @@ function hello() {
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle multiple lines with proper indentation', async () => {
|
||||
const originalContent = 'class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n return this.value\n }\n}\n'
|
||||
it("should handle multiple lines with proper indentation", async () => {
|
||||
const originalContent =
|
||||
"class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n return this.value\n }\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
getValue() {
|
||||
@@ -218,11 +223,13 @@ function hello() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n // Add logging\n console.log("Getting value")\n return this.value\n }\n}\n')
|
||||
expect(result.content).toBe(
|
||||
'class Example {\n constructor() {\n this.value = 0\n }\n\n getValue() {\n // Add logging\n console.log("Getting value")\n return this.value\n }\n}\n',
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve whitespace exactly in the output', async () => {
|
||||
it("should preserve whitespace exactly in the output", async () => {
|
||||
const originalContent = " indented\n more indented\n back\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -242,8 +249,8 @@ function hello() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve indentation when adding new lines after existing content', async () => {
|
||||
const originalContent = ' onScroll={() => updateHighlights()}'
|
||||
it("should preserve indentation when adding new lines after existing content", async () => {
|
||||
const originalContent = " onScroll={() => updateHighlights()}"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
onScroll={() => updateHighlights()}
|
||||
@@ -258,11 +265,13 @@ function hello() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(' onScroll={() => updateHighlights()}\n onDragOver={(e) => {\n e.preventDefault()\n e.stopPropagation()\n }}')
|
||||
expect(result.content).toBe(
|
||||
" onScroll={() => updateHighlights()}\n onDragOver={(e) => {\n e.preventDefault()\n e.stopPropagation()\n }}",
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle varying indentation levels correctly', async () => {
|
||||
it("should handle varying indentation levels correctly", async () => {
|
||||
const originalContent = `
|
||||
class Example {
|
||||
constructor() {
|
||||
@@ -271,7 +280,7 @@ class Example {
|
||||
this.init();
|
||||
}
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
@@ -294,12 +303,13 @@ class Example {
|
||||
}
|
||||
}
|
||||
}
|
||||
>>>>>>> REPLACE`.trim();
|
||||
>>>>>>> REPLACE`.trim()
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`
|
||||
expect(result.content).toBe(
|
||||
`
|
||||
class Example {
|
||||
constructor() {
|
||||
this.value = 1;
|
||||
@@ -309,11 +319,12 @@ class Example {
|
||||
this.validate();
|
||||
}
|
||||
}
|
||||
}`.trim());
|
||||
}`.trim(),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle mixed indentation styles in the same file', async () => {
|
||||
it("should handle mixed indentation styles in the same file", async () => {
|
||||
const originalContent = `class Example {
|
||||
constructor() {
|
||||
this.value = 0;
|
||||
@@ -321,7 +332,7 @@ class Example {
|
||||
this.init();
|
||||
}
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
constructor() {
|
||||
@@ -338,9 +349,9 @@ class Example {
|
||||
this.validate();
|
||||
}
|
||||
}
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`class Example {
|
||||
@@ -351,17 +362,17 @@ class Example {
|
||||
this.validate();
|
||||
}
|
||||
}
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle Python-style significant whitespace', async () => {
|
||||
it("should handle Python-style significant whitespace", async () => {
|
||||
const originalContent = `def example():
|
||||
if condition:
|
||||
do_something()
|
||||
for item in items:
|
||||
process(item)
|
||||
return True`.trim();
|
||||
return True`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
if condition:
|
||||
@@ -374,9 +385,9 @@ class Example {
|
||||
while items:
|
||||
item = items.pop()
|
||||
process(item)
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`def example():
|
||||
@@ -385,18 +396,18 @@ class Example {
|
||||
while items:
|
||||
item = items.pop()
|
||||
process(item)
|
||||
return True`);
|
||||
return True`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should preserve empty lines with indentation', async () => {
|
||||
it("should preserve empty lines with indentation", async () => {
|
||||
const originalContent = `function test() {
|
||||
const x = 1;
|
||||
|
||||
if (x) {
|
||||
return true;
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
const x = 1;
|
||||
@@ -407,9 +418,9 @@ class Example {
|
||||
|
||||
// Check x
|
||||
if (x) {
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`function test() {
|
||||
@@ -419,18 +430,18 @@ class Example {
|
||||
if (x) {
|
||||
return true;
|
||||
}
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle indentation when replacing entire blocks', async () => {
|
||||
it("should handle indentation when replacing entire blocks", async () => {
|
||||
const originalContent = `class Test {
|
||||
method() {
|
||||
if (true) {
|
||||
console.log("test");
|
||||
}
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
method() {
|
||||
@@ -448,9 +459,9 @@ class Example {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`class Test {
|
||||
@@ -463,11 +474,11 @@ class Example {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle negative indentation relative to search content', async () => {
|
||||
it("should handle negative indentation relative to search content", async () => {
|
||||
const originalContent = `class Example {
|
||||
constructor() {
|
||||
if (true) {
|
||||
@@ -475,7 +486,7 @@ class Example {
|
||||
this.setup();
|
||||
}
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
this.init();
|
||||
@@ -483,9 +494,9 @@ class Example {
|
||||
=======
|
||||
this.init();
|
||||
this.setup();
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`class Example {
|
||||
@@ -495,26 +506,26 @@ class Example {
|
||||
this.setup();
|
||||
}
|
||||
}
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle extreme negative indentation (no indent)', async () => {
|
||||
it("should handle extreme negative indentation (no indent)", async () => {
|
||||
const originalContent = `class Example {
|
||||
constructor() {
|
||||
if (true) {
|
||||
this.init();
|
||||
}
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
this.init();
|
||||
=======
|
||||
this.init();
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`class Example {
|
||||
@@ -523,11 +534,11 @@ this.init();
|
||||
this.init();
|
||||
}
|
||||
}
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should handle mixed indentation changes in replace block', async () => {
|
||||
it("should handle mixed indentation changes in replace block", async () => {
|
||||
const originalContent = `class Example {
|
||||
constructor() {
|
||||
if (true) {
|
||||
@@ -536,7 +547,7 @@ this.init();
|
||||
this.validate();
|
||||
}
|
||||
}
|
||||
}`.trim();
|
||||
}`.trim()
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
this.init();
|
||||
@@ -546,9 +557,9 @@ this.init();
|
||||
this.init();
|
||||
this.setup();
|
||||
this.validate();
|
||||
>>>>>>> REPLACE`;
|
||||
>>>>>>> REPLACE`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent);
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`class Example {
|
||||
@@ -559,11 +570,11 @@ this.init();
|
||||
this.validate();
|
||||
}
|
||||
}
|
||||
}`);
|
||||
}`)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
it('should find matches from middle out', async () => {
|
||||
it("should find matches from middle out", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return "target";
|
||||
@@ -621,16 +632,16 @@ function five() {
|
||||
})
|
||||
})
|
||||
|
||||
describe('line number stripping', () => {
|
||||
describe('line number stripping', () => {
|
||||
describe("line number stripping", () => {
|
||||
describe("line number stripping", () => {
|
||||
let strategy: SearchReplaceDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new SearchReplaceDiffStrategy()
|
||||
})
|
||||
|
||||
it('should strip line numbers from both search and replace sections', async () => {
|
||||
const originalContent = 'function test() {\n return true;\n}\n'
|
||||
it("should strip line numbers from both search and replace sections", async () => {
|
||||
const originalContent = "function test() {\n return true;\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
1 | function test() {
|
||||
@@ -645,12 +656,12 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('function test() {\n return false;\n}\n')
|
||||
expect(result.content).toBe("function test() {\n return false;\n}\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should strip line numbers with leading spaces', async () => {
|
||||
const originalContent = 'function test() {\n return true;\n}\n'
|
||||
it("should strip line numbers with leading spaces", async () => {
|
||||
const originalContent = "function test() {\n return true;\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
1 | function test() {
|
||||
@@ -665,12 +676,12 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('function test() {\n return false;\n}\n')
|
||||
expect(result.content).toBe("function test() {\n return false;\n}\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should not strip when not all lines have numbers in either section', async () => {
|
||||
const originalContent = 'function test() {\n return true;\n}\n'
|
||||
it("should not strip when not all lines have numbers in either section", async () => {
|
||||
const originalContent = "function test() {\n return true;\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
1 | function test() {
|
||||
@@ -686,8 +697,8 @@ function five() {
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should preserve content that naturally starts with pipe', async () => {
|
||||
const originalContent = '|header|another|\n|---|---|\n|data|more|\n'
|
||||
it("should preserve content that naturally starts with pipe", async () => {
|
||||
const originalContent = "|header|another|\n|---|---|\n|data|more|\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
1 | |header|another|
|
||||
@@ -702,12 +713,12 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('|header|another|\n|---|---|\n|data|updated|\n')
|
||||
expect(result.content).toBe("|header|another|\n|---|---|\n|data|updated|\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve indentation when stripping line numbers', async () => {
|
||||
const originalContent = ' function test() {\n return true;\n }\n'
|
||||
it("should preserve indentation when stripping line numbers", async () => {
|
||||
const originalContent = " function test() {\n return true;\n }\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
1 | function test() {
|
||||
@@ -722,12 +733,12 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(' function test() {\n return false;\n }\n')
|
||||
expect(result.content).toBe(" function test() {\n return false;\n }\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle different line numbers between sections', async () => {
|
||||
const originalContent = 'function test() {\n return true;\n}\n'
|
||||
it("should handle different line numbers between sections", async () => {
|
||||
const originalContent = "function test() {\n return true;\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
10 | function test() {
|
||||
@@ -742,12 +753,12 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('function test() {\n return false;\n}\n')
|
||||
expect(result.content).toBe("function test() {\n return false;\n}\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should not strip content that starts with pipe but no line number', async () => {
|
||||
const originalContent = '| Pipe\n|---|\n| Data\n'
|
||||
it("should not strip content that starts with pipe but no line number", async () => {
|
||||
const originalContent = "| Pipe\n|---|\n| Data\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
| Pipe
|
||||
@@ -762,12 +773,12 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('| Pipe\n|---|\n| Updated\n')
|
||||
expect(result.content).toBe("| Pipe\n|---|\n| Updated\n")
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle mix of line-numbered and pipe-only content', async () => {
|
||||
const originalContent = '| Pipe\n|---|\n| Data\n'
|
||||
it("should handle mix of line-numbered and pipe-only content", async () => {
|
||||
const originalContent = "| Pipe\n|---|\n| Data\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
| Pipe
|
||||
@@ -782,21 +793,21 @@ function five() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('1 | | Pipe\n2 | |---|\n3 | | NewData\n')
|
||||
expect(result.content).toBe("1 | | Pipe\n2 | |---|\n3 | | NewData\n")
|
||||
}
|
||||
})
|
||||
})
|
||||
});
|
||||
})
|
||||
|
||||
describe('insertion/deletion', () => {
|
||||
describe("insertion/deletion", () => {
|
||||
let strategy: SearchReplaceDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new SearchReplaceDiffStrategy()
|
||||
})
|
||||
|
||||
describe('deletion', () => {
|
||||
it('should delete code when replace block is empty', async () => {
|
||||
describe("deletion", () => {
|
||||
it("should delete code when replace block is empty", async () => {
|
||||
const originalContent = `function test() {
|
||||
console.log("hello");
|
||||
// Comment to remove
|
||||
@@ -818,7 +829,7 @@ function five() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should delete multiple lines when replace block is empty', async () => {
|
||||
it("should delete multiple lines when replace block is empty", async () => {
|
||||
const originalContent = `class Example {
|
||||
constructor() {
|
||||
// Initialize
|
||||
@@ -848,7 +859,7 @@ function five() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve indentation when deleting nested code', async () => {
|
||||
it("should preserve indentation when deleting nested code", async () => {
|
||||
const originalContent = `function outer() {
|
||||
if (true) {
|
||||
// Remove this
|
||||
@@ -877,8 +888,8 @@ function five() {
|
||||
})
|
||||
})
|
||||
|
||||
describe('insertion', () => {
|
||||
it('should insert code at specified line when search block is empty', async () => {
|
||||
describe("insertion", () => {
|
||||
it("should insert code at specified line when search block is empty", async () => {
|
||||
const originalContent = `function test() {
|
||||
const x = 1;
|
||||
return x;
|
||||
@@ -900,7 +911,7 @@ function five() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should preserve indentation when inserting at nested location', async () => {
|
||||
it("should preserve indentation when inserting at nested location", async () => {
|
||||
const originalContent = `function test() {
|
||||
if (true) {
|
||||
const x = 1;
|
||||
@@ -926,7 +937,7 @@ function five() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle insertion at start of file', async () => {
|
||||
it("should handle insertion at start of file", async () => {
|
||||
const originalContent = `function test() {
|
||||
return true;
|
||||
}`
|
||||
@@ -950,7 +961,7 @@ function test() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle insertion at end of file', async () => {
|
||||
it("should handle insertion at end of file", async () => {
|
||||
const originalContent = `function test() {
|
||||
return true;
|
||||
}`
|
||||
@@ -972,7 +983,7 @@ function test() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should error if no start_line is provided for insertion', async () => {
|
||||
it("should error if no start_line is provided for insertion", async () => {
|
||||
const originalContent = `function test() {
|
||||
return true;
|
||||
}`
|
||||
@@ -988,14 +999,15 @@ console.log("test");
|
||||
})
|
||||
})
|
||||
|
||||
describe('fuzzy matching', () => {
|
||||
describe("fuzzy matching", () => {
|
||||
let strategy: SearchReplaceDiffStrategy
|
||||
beforeEach(() => {
|
||||
strategy = new SearchReplaceDiffStrategy(0.9, 5) // 90% similarity threshold, 5 line buffer for tests
|
||||
})
|
||||
|
||||
it('should match content with small differences (>90% similar)', async () => {
|
||||
const originalContent = 'function getData() {\n const results = fetchData();\n return results.filter(Boolean);\n}\n'
|
||||
it("should match content with small differences (>90% similar)", async () => {
|
||||
const originalContent =
|
||||
"function getData() {\n const results = fetchData();\n return results.filter(Boolean);\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
function getData() {
|
||||
@@ -1014,12 +1026,14 @@ function getData() {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('function getData() {\n const data = fetchData();\n return data.filter(Boolean);\n}\n')
|
||||
expect(result.content).toBe(
|
||||
"function getData() {\n const data = fetchData();\n return data.filter(Boolean);\n}\n",
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should not match when content is too different (<90% similar)', async () => {
|
||||
const originalContent = 'function processUsers(data) {\n return data.map(user => user.name);\n}\n'
|
||||
it("should not match when content is too different (<90% similar)", async () => {
|
||||
const originalContent = "function processUsers(data) {\n return data.map(user => user.name);\n}\n"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
function handleItems(items) {
|
||||
@@ -1035,8 +1049,8 @@ function processData(data) {
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should match content with extra whitespace', async () => {
|
||||
const originalContent = 'function sum(a, b) {\n return a + b;\n}'
|
||||
it("should match content with extra whitespace", async () => {
|
||||
const originalContent = "function sum(a, b) {\n return a + b;\n}"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
function sum(a, b) {
|
||||
@@ -1051,12 +1065,12 @@ function sum(a, b) {
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe('function sum(a, b) {\n return a + b + 1;\n}')
|
||||
expect(result.content).toBe("function sum(a, b) {\n return a + b + 1;\n}")
|
||||
}
|
||||
})
|
||||
|
||||
it('should not exact match empty lines', async () => {
|
||||
const originalContent = 'function sum(a, b) {\n\n return a + b;\n}'
|
||||
it("should not exact match empty lines", async () => {
|
||||
const originalContent = "function sum(a, b) {\n\n return a + b;\n}"
|
||||
const diffContent = `test.ts
|
||||
<<<<<<< SEARCH
|
||||
function sum(a, b) {
|
||||
@@ -1073,14 +1087,14 @@ function sum(a, b) {
|
||||
})
|
||||
})
|
||||
|
||||
describe('line-constrained search', () => {
|
||||
describe("line-constrained search", () => {
|
||||
let strategy: SearchReplaceDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new SearchReplaceDiffStrategy(0.9, 5)
|
||||
})
|
||||
|
||||
it('should find and replace within specified line range', async () => {
|
||||
it("should find and replace within specified line range", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1122,7 +1136,7 @@ function three() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should find and replace within buffer zone (5 lines before/after)', async () => {
|
||||
it("should find and replace within buffer zone (5 lines before/after)", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1166,7 +1180,7 @@ function three() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should not find matches outside search range and buffer zone', async () => {
|
||||
it("should not find matches outside search range and buffer zone", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1205,7 +1219,7 @@ function five() {
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle search range at start of file', async () => {
|
||||
it("should handle search range at start of file", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1239,7 +1253,7 @@ function two() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle search range at end of file', async () => {
|
||||
it("should handle search range at end of file", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1273,7 +1287,7 @@ function two() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should match specific instance of duplicate code using line numbers', async () => {
|
||||
it("should match specific instance of duplicate code using line numbers", async () => {
|
||||
const originalContent = `
|
||||
function processData(data) {
|
||||
return data.map(x => x * 2);
|
||||
@@ -1330,7 +1344,7 @@ function moreStuff() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should search from start line to end of file when only start_line is provided', async () => {
|
||||
it("should search from start line to end of file when only start_line is provided", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1373,7 +1387,7 @@ function three() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should search from start of file to end line when only end_line is provided', async () => {
|
||||
it("should search from start of file to end line when only end_line is provided", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1416,7 +1430,7 @@ function three() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should prioritize exact line match over expanded search', async () => {
|
||||
it("should prioritize exact line match over expanded search", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1468,7 +1482,7 @@ function two() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should fall back to expanded search only if exact match fails', async () => {
|
||||
it("should fall back to expanded search only if exact match fails", async () => {
|
||||
const originalContent = `
|
||||
function one() {
|
||||
return 1;
|
||||
@@ -1512,32 +1526,32 @@ function two() {
|
||||
})
|
||||
})
|
||||
|
||||
describe('getToolDescription', () => {
|
||||
describe("getToolDescription", () => {
|
||||
let strategy: SearchReplaceDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new SearchReplaceDiffStrategy()
|
||||
})
|
||||
|
||||
it('should include the current working directory', async () => {
|
||||
const cwd = '/test/dir'
|
||||
it("should include the current working directory", async () => {
|
||||
const cwd = "/test/dir"
|
||||
const description = await strategy.getToolDescription({ cwd })
|
||||
expect(description).toContain(`relative to the current working directory ${cwd}`)
|
||||
})
|
||||
|
||||
it('should include required format elements', async () => {
|
||||
const description = await strategy.getToolDescription({ cwd: '/test' })
|
||||
expect(description).toContain('<<<<<<< SEARCH')
|
||||
expect(description).toContain('=======')
|
||||
expect(description).toContain('>>>>>>> REPLACE')
|
||||
expect(description).toContain('<apply_diff>')
|
||||
expect(description).toContain('</apply_diff>')
|
||||
it("should include required format elements", async () => {
|
||||
const description = await strategy.getToolDescription({ cwd: "/test" })
|
||||
expect(description).toContain("<<<<<<< SEARCH")
|
||||
expect(description).toContain("=======")
|
||||
expect(description).toContain(">>>>>>> REPLACE")
|
||||
expect(description).toContain("<apply_diff>")
|
||||
expect(description).toContain("</apply_diff>")
|
||||
})
|
||||
|
||||
it('should document start_line and end_line parameters', async () => {
|
||||
const description = await strategy.getToolDescription({ cwd: '/test' })
|
||||
expect(description).toContain('start_line: (required) The line number where the search block starts.')
|
||||
expect(description).toContain('end_line: (required) The line number where the search block ends.')
|
||||
it("should document start_line and end_line parameters", async () => {
|
||||
const description = await strategy.getToolDescription({ cwd: "/test" })
|
||||
expect(description).toContain("start_line: (required) The line number where the search block starts.")
|
||||
expect(description).toContain("end_line: (required) The line number where the search block ends.")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
import { UnifiedDiffStrategy } from '../unified'
|
||||
import { UnifiedDiffStrategy } from "../unified"
|
||||
|
||||
describe('UnifiedDiffStrategy', () => {
|
||||
describe("UnifiedDiffStrategy", () => {
|
||||
let strategy: UnifiedDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new UnifiedDiffStrategy()
|
||||
})
|
||||
|
||||
describe('getToolDescription', () => {
|
||||
it('should return tool description with correct cwd', () => {
|
||||
const cwd = '/test/path'
|
||||
describe("getToolDescription", () => {
|
||||
it("should return tool description with correct cwd", () => {
|
||||
const cwd = "/test/path"
|
||||
const description = strategy.getToolDescription({ cwd })
|
||||
|
||||
expect(description).toContain('apply_diff')
|
||||
expect(description).toContain("apply_diff")
|
||||
expect(description).toContain(cwd)
|
||||
expect(description).toContain('Parameters:')
|
||||
expect(description).toContain('Format Requirements:')
|
||||
expect(description).toContain("Parameters:")
|
||||
expect(description).toContain("Format Requirements:")
|
||||
})
|
||||
})
|
||||
|
||||
describe('applyDiff', () => {
|
||||
it('should successfully apply a function modification diff', async () => {
|
||||
describe("applyDiff", () => {
|
||||
it("should successfully apply a function modification diff", async () => {
|
||||
const originalContent = `import { Logger } from '../logger';
|
||||
|
||||
function calculateTotal(items: number[]): number {
|
||||
@@ -65,7 +65,7 @@ export { calculateTotal };`
|
||||
}
|
||||
})
|
||||
|
||||
it('should successfully apply a diff adding a new method', async () => {
|
||||
it("should successfully apply a diff adding a new method", async () => {
|
||||
const originalContent = `class Calculator {
|
||||
add(a: number, b: number): number {
|
||||
return a + b;
|
||||
@@ -102,7 +102,7 @@ export { calculateTotal };`
|
||||
}
|
||||
})
|
||||
|
||||
it('should successfully apply a diff modifying imports', async () => {
|
||||
it("should successfully apply a diff modifying imports", async () => {
|
||||
const originalContent = `import { useState } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
@@ -140,7 +140,7 @@ function App() {
|
||||
}
|
||||
})
|
||||
|
||||
it('should successfully apply a diff with multiple hunks', async () => {
|
||||
it("should successfully apply a diff with multiple hunks", async () => {
|
||||
const originalContent = `import { readFile, writeFile } from 'fs';
|
||||
|
||||
function processFile(path: string) {
|
||||
@@ -205,8 +205,8 @@ export { processFile };`
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle empty original content', async () => {
|
||||
const originalContent = ''
|
||||
it("should handle empty original content", async () => {
|
||||
const originalContent = ""
|
||||
const diffContent = `--- empty.ts
|
||||
+++ empty.ts
|
||||
@@ -0,0 +1,3 @@
|
||||
@@ -226,4 +226,3 @@ export { processFile };`
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -265,8 +265,8 @@ describe("applyGitFallback", () => {
|
||||
{ type: "context", content: "line1", indent: "" },
|
||||
{ type: "remove", content: "line2", indent: "" },
|
||||
{ type: "add", content: "new line2", indent: "" },
|
||||
{ type: "context", content: "line3", indent: "" }
|
||||
]
|
||||
{ type: "context", content: "line3", indent: "" },
|
||||
],
|
||||
} as Hunk
|
||||
|
||||
const content = ["line1", "line2", "line3"]
|
||||
@@ -281,8 +281,8 @@ describe("applyGitFallback", () => {
|
||||
const hunk = {
|
||||
changes: [
|
||||
{ type: "context", content: "nonexistent", indent: "" },
|
||||
{ type: "add", content: "new line", indent: "" }
|
||||
]
|
||||
{ type: "add", content: "new line", indent: "" },
|
||||
],
|
||||
} as Hunk
|
||||
|
||||
const content = ["line1", "line2", "line3"]
|
||||
|
||||
@@ -3,7 +3,7 @@ import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMa
|
||||
type SearchStrategy = (
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex?: number
|
||||
startIndex?: number,
|
||||
) => {
|
||||
index: number
|
||||
confidence: number
|
||||
|
||||
@@ -28,11 +28,7 @@ function inferIndentation(line: string, contextLines: string[], previousIndent:
|
||||
}
|
||||
|
||||
// Context matching edit strategy
|
||||
export function applyContextMatching(
|
||||
hunk: Hunk,
|
||||
content: string[],
|
||||
matchPosition: number,
|
||||
): EditResult {
|
||||
export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||
if (matchPosition === -1) {
|
||||
return { confidence: 0, result: content, strategy: "context" }
|
||||
}
|
||||
@@ -85,16 +81,12 @@ export function applyContextMatching(
|
||||
return {
|
||||
confidence,
|
||||
result: newResult,
|
||||
strategy: "context"
|
||||
strategy: "context",
|
||||
}
|
||||
}
|
||||
|
||||
// DMP edit strategy
|
||||
export function applyDMP(
|
||||
hunk: Hunk,
|
||||
content: string[],
|
||||
matchPosition: number,
|
||||
): EditResult {
|
||||
export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||
if (matchPosition === -1) {
|
||||
return { confidence: 0, result: content, strategy: "dmp" }
|
||||
}
|
||||
@@ -276,12 +268,12 @@ export async function applyEdit(
|
||||
content: string[],
|
||||
matchPosition: number,
|
||||
confidence: number,
|
||||
confidenceThreshold: number = 0.97
|
||||
confidenceThreshold: number = 0.97,
|
||||
): Promise<EditResult> {
|
||||
// Don't attempt regular edits if confidence is too low
|
||||
if (confidence < confidenceThreshold) {
|
||||
console.log(
|
||||
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`
|
||||
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
|
||||
)
|
||||
return applyGitFallback(hunk, content)
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ Your diff here
|
||||
originalContent: string,
|
||||
diffContent: string,
|
||||
startLine?: number,
|
||||
endLine?: number
|
||||
endLine?: number,
|
||||
): Promise<DiffResult> {
|
||||
const parsedDiff = this.parseUnifiedDiff(diffContent)
|
||||
const originalLines = originalContent.split("\n")
|
||||
@@ -280,7 +280,7 @@ Your diff here
|
||||
subHunkResult,
|
||||
subSearchResult.index,
|
||||
subSearchResult.confidence,
|
||||
this.confidenceThreshold
|
||||
this.confidenceThreshold,
|
||||
)
|
||||
if (subEditResult.confidence >= this.confidenceThreshold) {
|
||||
subHunkResult = subEditResult.result
|
||||
@@ -302,12 +302,12 @@ Your diff here
|
||||
const contextRatio = contextLines / totalLines
|
||||
|
||||
let errorMsg = `Failed to find a matching location in the file (${Math.floor(
|
||||
confidence * 100
|
||||
confidence * 100,
|
||||
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
|
||||
errorMsg += "Debug Info:\n"
|
||||
errorMsg += `- Search Strategy Used: ${strategy}\n`
|
||||
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
|
||||
contextRatio * 100
|
||||
contextRatio * 100,
|
||||
)}%)\n`
|
||||
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
|
||||
|
||||
@@ -339,7 +339,7 @@ Your diff here
|
||||
} else {
|
||||
// Edit failure - likely due to content mismatch
|
||||
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
|
||||
editResult.confidence * 100
|
||||
editResult.confidence * 100,
|
||||
)}% confidence)\n\n`
|
||||
errorMsg += "Debug Info:\n"
|
||||
errorMsg += "- The location was found but the content didn't match exactly\n"
|
||||
|
||||
@@ -69,26 +69,26 @@ export function getDMPSimilarity(original: string, modified: string): number {
|
||||
export function validateEditResult(hunk: Hunk, result: string): number {
|
||||
// Build the expected text from the hunk
|
||||
const expectedText = hunk.changes
|
||||
.filter(change => change.type === "context" || change.type === "add")
|
||||
.map(change => change.indent ? change.indent + change.content : change.content)
|
||||
.join("\n");
|
||||
.filter((change) => change.type === "context" || change.type === "add")
|
||||
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||
.join("\n")
|
||||
|
||||
// Calculate similarity between the result and expected text
|
||||
const similarity = getDMPSimilarity(expectedText, result);
|
||||
const similarity = getDMPSimilarity(expectedText, result)
|
||||
|
||||
// If the result is unchanged from original, return low confidence
|
||||
const originalText = hunk.changes
|
||||
.filter(change => change.type === "context" || change.type === "remove")
|
||||
.map(change => change.indent ? change.indent + change.content : change.content)
|
||||
.join("\n");
|
||||
.filter((change) => change.type === "context" || change.type === "remove")
|
||||
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||
.join("\n")
|
||||
|
||||
const originalSimilarity = getDMPSimilarity(originalText, result);
|
||||
const originalSimilarity = getDMPSimilarity(originalText, result)
|
||||
if (originalSimilarity > 0.97 && similarity !== 1) {
|
||||
return 0.8 * similarity; // Some confidence since we found the right location
|
||||
return 0.8 * similarity // Some confidence since we found the right location
|
||||
}
|
||||
|
||||
// For partial matches, scale the confidence but keep it high if we're close
|
||||
return similarity;
|
||||
return similarity
|
||||
}
|
||||
|
||||
// Helper function to validate context lines against original content
|
||||
@@ -114,7 +114,7 @@ function validateContextLines(searchStr: string, content: string, confidenceThre
|
||||
function createOverlappingWindows(
|
||||
content: string[],
|
||||
searchSize: number,
|
||||
overlapSize: number = DEFAULT_OVERLAP_SIZE
|
||||
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||
): { window: string[]; startIndex: number }[] {
|
||||
const windows: { window: string[]; startIndex: number }[] = []
|
||||
|
||||
@@ -140,7 +140,7 @@ function createOverlappingWindows(
|
||||
// Helper function to combine overlapping matches
|
||||
function combineOverlappingMatches(
|
||||
matches: (SearchResult & { windowIndex: number })[],
|
||||
overlapSize: number = DEFAULT_OVERLAP_SIZE
|
||||
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||
): SearchResult[] {
|
||||
if (matches.length === 0) {
|
||||
return []
|
||||
@@ -162,7 +162,7 @@ function combineOverlappingMatches(
|
||||
(m) =>
|
||||
Math.abs(m.windowIndex - match.windowIndex) === 1 &&
|
||||
Math.abs(m.index - match.index) <= overlapSize &&
|
||||
!usedIndices.has(m.windowIndex)
|
||||
!usedIndices.has(m.windowIndex),
|
||||
)
|
||||
|
||||
if (overlapping.length > 0) {
|
||||
@@ -196,7 +196,7 @@ export function findExactMatch(
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex: number = 0,
|
||||
confidenceThreshold: number = 0.97
|
||||
confidenceThreshold: number = 0.97,
|
||||
): SearchResult {
|
||||
const searchLines = searchStr.split("\n")
|
||||
const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
|
||||
@@ -210,7 +210,7 @@ export function findExactMatch(
|
||||
const matchedContent = windowData.window
|
||||
.slice(
|
||||
windowStr.slice(0, exactMatch).split("\n").length - 1,
|
||||
windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length
|
||||
windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length,
|
||||
)
|
||||
.join("\n")
|
||||
|
||||
@@ -236,7 +236,7 @@ export function findSimilarityMatch(
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex: number = 0,
|
||||
confidenceThreshold: number = 0.97
|
||||
confidenceThreshold: number = 0.97,
|
||||
): SearchResult {
|
||||
const searchLines = searchStr.split("\n")
|
||||
let bestScore = 0
|
||||
@@ -269,7 +269,7 @@ export function findLevenshteinMatch(
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex: number = 0,
|
||||
confidenceThreshold: number = 0.97
|
||||
confidenceThreshold: number = 0.97,
|
||||
): SearchResult {
|
||||
const searchLines = searchStr.split("\n")
|
||||
const candidates = []
|
||||
@@ -324,7 +324,7 @@ export function findAnchorMatch(
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex: number = 0,
|
||||
confidenceThreshold: number = 0.97
|
||||
confidenceThreshold: number = 0.97,
|
||||
): SearchResult {
|
||||
const searchLines = searchStr.split("\n")
|
||||
const { first, last } = identifyAnchors(searchStr)
|
||||
@@ -391,7 +391,7 @@ export function findBestMatch(
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex: number = 0,
|
||||
confidenceThreshold: number = 0.97
|
||||
confidenceThreshold: number = 0.97,
|
||||
): SearchResult {
|
||||
const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
export type Change = {
|
||||
type: 'context' | 'add' | 'remove';
|
||||
content: string;
|
||||
indent: string;
|
||||
originalLine?: string;
|
||||
};
|
||||
type: "context" | "add" | "remove"
|
||||
content: string
|
||||
indent: string
|
||||
originalLine?: string
|
||||
}
|
||||
|
||||
export type Hunk = {
|
||||
changes: Change[];
|
||||
};
|
||||
changes: Change[]
|
||||
}
|
||||
|
||||
export type Diff = {
|
||||
hunks: Hunk[];
|
||||
};
|
||||
hunks: Hunk[]
|
||||
}
|
||||
|
||||
export type EditResult = {
|
||||
confidence: number;
|
||||
result: string[];
|
||||
strategy: string;
|
||||
};
|
||||
confidence: number
|
||||
result: string[]
|
||||
strategy: string
|
||||
}
|
||||
|
||||
@@ -1,68 +1,70 @@
|
||||
import { DiffStrategy, DiffResult } from "../types"
|
||||
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
|
||||
|
||||
const BUFFER_LINES = 20; // Number of extra context lines to show before and after matches
|
||||
const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches
|
||||
|
||||
function levenshteinDistance(a: string, b: string): number {
|
||||
const matrix: number[][] = [];
|
||||
const matrix: number[][] = []
|
||||
|
||||
// Initialize matrix
|
||||
for (let i = 0; i <= a.length; i++) {
|
||||
matrix[i] = [i];
|
||||
matrix[i] = [i]
|
||||
}
|
||||
for (let j = 0; j <= b.length; j++) {
|
||||
matrix[0][j] = j;
|
||||
matrix[0][j] = j
|
||||
}
|
||||
|
||||
// Fill matrix
|
||||
for (let i = 1; i <= a.length; i++) {
|
||||
for (let j = 1; j <= b.length; j++) {
|
||||
if (a[i - 1] === b[j - 1]) {
|
||||
matrix[i][j] = matrix[i-1][j-1];
|
||||
matrix[i][j] = matrix[i - 1][j - 1]
|
||||
} else {
|
||||
matrix[i][j] = Math.min(
|
||||
matrix[i - 1][j - 1] + 1, // substitution
|
||||
matrix[i][j - 1] + 1, // insertion
|
||||
matrix[i-1][j] + 1 // deletion
|
||||
);
|
||||
matrix[i - 1][j] + 1, // deletion
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matrix[a.length][b.length];
|
||||
return matrix[a.length][b.length]
|
||||
}
|
||||
|
||||
function getSimilarity(original: string, search: string): number {
|
||||
if (search === '') {
|
||||
return 1;
|
||||
if (search === "") {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Normalize strings by removing extra whitespace but preserve case
|
||||
const normalizeStr = (str: string) => str.replace(/\s+/g, ' ').trim();
|
||||
const normalizeStr = (str: string) => str.replace(/\s+/g, " ").trim()
|
||||
|
||||
const normalizedOriginal = normalizeStr(original);
|
||||
const normalizedSearch = normalizeStr(search);
|
||||
const normalizedOriginal = normalizeStr(original)
|
||||
const normalizedSearch = normalizeStr(search)
|
||||
|
||||
if (normalizedOriginal === normalizedSearch) { return 1; }
|
||||
if (normalizedOriginal === normalizedSearch) {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Calculate Levenshtein distance
|
||||
const distance = levenshteinDistance(normalizedOriginal, normalizedSearch);
|
||||
const distance = levenshteinDistance(normalizedOriginal, normalizedSearch)
|
||||
|
||||
// Calculate similarity ratio (0 to 1, where 1 is exact match)
|
||||
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length);
|
||||
return 1 - (distance / maxLength);
|
||||
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length)
|
||||
return 1 - distance / maxLength
|
||||
}
|
||||
|
||||
export class SearchReplaceDiffStrategy implements DiffStrategy {
|
||||
private fuzzyThreshold: number;
|
||||
private bufferLines: number;
|
||||
private fuzzyThreshold: number
|
||||
private bufferLines: number
|
||||
|
||||
constructor(fuzzyThreshold?: number, bufferLines?: number) {
|
||||
// Use provided threshold or default to exact matching (1.0)
|
||||
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
|
||||
// so we use it directly here
|
||||
this.fuzzyThreshold = fuzzyThreshold ?? 1.0;
|
||||
this.bufferLines = bufferLines ?? BUFFER_LINES;
|
||||
this.fuzzyThreshold = fuzzyThreshold ?? 1.0
|
||||
this.bufferLines = bufferLines ?? BUFFER_LINES
|
||||
}
|
||||
|
||||
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||
@@ -127,191 +129,202 @@ Your search/replace content here
|
||||
</apply_diff>`
|
||||
}
|
||||
|
||||
async applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult> {
|
||||
async applyDiff(
|
||||
originalContent: string,
|
||||
diffContent: string,
|
||||
startLine?: number,
|
||||
endLine?: number,
|
||||
): Promise<DiffResult> {
|
||||
// Extract the search and replace blocks
|
||||
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/);
|
||||
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
|
||||
if (!match) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`
|
||||
};
|
||||
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
|
||||
}
|
||||
}
|
||||
|
||||
let [_, searchContent, replaceContent] = match;
|
||||
let [_, searchContent, replaceContent] = match
|
||||
|
||||
// Detect line ending from original content
|
||||
const lineEnding = originalContent.includes('\r\n') ? '\r\n' : '\n';
|
||||
const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"
|
||||
|
||||
// Strip line numbers from search and replace content if every line starts with a line number
|
||||
if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
|
||||
searchContent = stripLineNumbers(searchContent);
|
||||
replaceContent = stripLineNumbers(replaceContent);
|
||||
searchContent = stripLineNumbers(searchContent)
|
||||
replaceContent = stripLineNumbers(replaceContent)
|
||||
}
|
||||
|
||||
// Split content into lines, handling both \n and \r\n
|
||||
const searchLines = searchContent === '' ? [] : searchContent.split(/\r?\n/);
|
||||
const replaceLines = replaceContent === '' ? [] : replaceContent.split(/\r?\n/);
|
||||
const originalLines = originalContent.split(/\r?\n/);
|
||||
const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
|
||||
const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
|
||||
const originalLines = originalContent.split(/\r?\n/)
|
||||
|
||||
// Validate that empty search requires start line
|
||||
if (searchLines.length === 0 && !startLine) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`
|
||||
};
|
||||
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that empty search requires same start and end line
|
||||
if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`
|
||||
};
|
||||
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize search variables
|
||||
let matchIndex = -1;
|
||||
let bestMatchScore = 0;
|
||||
let bestMatchContent = "";
|
||||
const searchChunk = searchLines.join('\n');
|
||||
let matchIndex = -1
|
||||
let bestMatchScore = 0
|
||||
let bestMatchContent = ""
|
||||
const searchChunk = searchLines.join("\n")
|
||||
|
||||
// Determine search bounds
|
||||
let searchStartIndex = 0;
|
||||
let searchEndIndex = originalLines.length;
|
||||
let searchStartIndex = 0
|
||||
let searchEndIndex = originalLines.length
|
||||
|
||||
// Validate and handle line range if provided
|
||||
if (startLine && endLine) {
|
||||
// Convert to 0-based index
|
||||
const exactStartIndex = startLine - 1;
|
||||
const exactEndIndex = endLine - 1;
|
||||
const exactStartIndex = startLine - 1
|
||||
const exactEndIndex = endLine - 1
|
||||
|
||||
if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Try exact match first
|
||||
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join('\n');
|
||||
const similarity = getSimilarity(originalChunk, searchChunk);
|
||||
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
|
||||
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||
if (similarity >= this.fuzzyThreshold) {
|
||||
matchIndex = exactStartIndex;
|
||||
bestMatchScore = similarity;
|
||||
bestMatchContent = originalChunk;
|
||||
matchIndex = exactStartIndex
|
||||
bestMatchScore = similarity
|
||||
bestMatchContent = originalChunk
|
||||
} else {
|
||||
// Set bounds for buffered search
|
||||
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1));
|
||||
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines);
|
||||
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
|
||||
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
|
||||
}
|
||||
}
|
||||
|
||||
// If no match found yet, try middle-out search within bounds
|
||||
if (matchIndex === -1) {
|
||||
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2);
|
||||
let leftIndex = midPoint;
|
||||
let rightIndex = midPoint + 1;
|
||||
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
|
||||
let leftIndex = midPoint
|
||||
let rightIndex = midPoint + 1
|
||||
|
||||
// Search outward from the middle within bounds
|
||||
while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
|
||||
// Check left side if still in range
|
||||
if (leftIndex >= searchStartIndex) {
|
||||
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join('\n');
|
||||
const similarity = getSimilarity(originalChunk, searchChunk);
|
||||
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
|
||||
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||
if (similarity > bestMatchScore) {
|
||||
bestMatchScore = similarity;
|
||||
matchIndex = leftIndex;
|
||||
bestMatchContent = originalChunk;
|
||||
bestMatchScore = similarity
|
||||
matchIndex = leftIndex
|
||||
bestMatchContent = originalChunk
|
||||
}
|
||||
leftIndex--;
|
||||
leftIndex--
|
||||
}
|
||||
|
||||
// Check right side if still in range
|
||||
if (rightIndex <= searchEndIndex - searchLines.length) {
|
||||
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join('\n');
|
||||
const similarity = getSimilarity(originalChunk, searchChunk);
|
||||
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
|
||||
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||
if (similarity > bestMatchScore) {
|
||||
bestMatchScore = similarity;
|
||||
matchIndex = rightIndex;
|
||||
bestMatchContent = originalChunk;
|
||||
bestMatchScore = similarity
|
||||
matchIndex = rightIndex
|
||||
bestMatchContent = originalChunk
|
||||
}
|
||||
rightIndex++;
|
||||
rightIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Require similarity to meet threshold
|
||||
if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
|
||||
const searchChunk = searchLines.join('\n');
|
||||
const originalContentSection = startLine !== undefined && endLine !== undefined
|
||||
const searchChunk = searchLines.join("\n")
|
||||
const originalContentSection =
|
||||
startLine !== undefined && endLine !== undefined
|
||||
? `\n\nOriginal Content:\n${addLineNumbers(
|
||||
originalLines.slice(
|
||||
originalLines
|
||||
.slice(
|
||||
Math.max(0, startLine - 1 - this.bufferLines),
|
||||
Math.min(originalLines.length, endLine + this.bufferLines)
|
||||
).join('\n'),
|
||||
Math.max(1, startLine - this.bufferLines)
|
||||
Math.min(originalLines.length, endLine + this.bufferLines),
|
||||
)
|
||||
.join("\n"),
|
||||
Math.max(1, startLine - this.bufferLines),
|
||||
)}`
|
||||
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join('\n'))}`;
|
||||
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`
|
||||
|
||||
const bestMatchSection = bestMatchContent
|
||||
? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
|
||||
: `\n\nBest Match Found:\n(no match)`;
|
||||
: `\n\nBest Match Found:\n(no match)`
|
||||
|
||||
const lineRange = startLine || endLine ?
|
||||
` at ${startLine ? `start: ${startLine}` : 'start'} to ${endLine ? `end: ${endLine}` : 'end'}` : '';
|
||||
const lineRange =
|
||||
startLine || endLine
|
||||
? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
|
||||
: ""
|
||||
return {
|
||||
success: false,
|
||||
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : 'start to end'}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`
|
||||
};
|
||||
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
|
||||
}
|
||||
}
|
||||
|
||||
// Get the matched lines from the original content
|
||||
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length);
|
||||
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)
|
||||
|
||||
// Get the exact indentation (preserving tabs/spaces) of each line
|
||||
const originalIndents = matchedLines.map(line => {
|
||||
const match = line.match(/^[\t ]*/);
|
||||
return match ? match[0] : '';
|
||||
});
|
||||
const originalIndents = matchedLines.map((line) => {
|
||||
const match = line.match(/^[\t ]*/)
|
||||
return match ? match[0] : ""
|
||||
})
|
||||
|
||||
// Get the exact indentation of each line in the search block
|
||||
const searchIndents = searchLines.map(line => {
|
||||
const match = line.match(/^[\t ]*/);
|
||||
return match ? match[0] : '';
|
||||
});
|
||||
const searchIndents = searchLines.map((line) => {
|
||||
const match = line.match(/^[\t ]*/)
|
||||
return match ? match[0] : ""
|
||||
})
|
||||
|
||||
// Apply the replacement while preserving exact indentation
|
||||
const indentedReplaceLines = replaceLines.map((line, i) => {
|
||||
// Get the matched line's exact indentation
|
||||
const matchedIndent = originalIndents[0] || '';
|
||||
const matchedIndent = originalIndents[0] || ""
|
||||
|
||||
// Get the current line's indentation relative to the search content
|
||||
const currentIndentMatch = line.match(/^[\t ]*/);
|
||||
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : '';
|
||||
const searchBaseIndent = searchIndents[0] || '';
|
||||
const currentIndentMatch = line.match(/^[\t ]*/)
|
||||
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
|
||||
const searchBaseIndent = searchIndents[0] || ""
|
||||
|
||||
// Calculate the relative indentation level
|
||||
const searchBaseLevel = searchBaseIndent.length;
|
||||
const currentLevel = currentIndent.length;
|
||||
const relativeLevel = currentLevel - searchBaseLevel;
|
||||
const searchBaseLevel = searchBaseIndent.length
|
||||
const currentLevel = currentIndent.length
|
||||
const relativeLevel = currentLevel - searchBaseLevel
|
||||
|
||||
// If relative level is negative, remove indentation from matched indent
|
||||
// If positive, add to matched indent
|
||||
const finalIndent = relativeLevel < 0
|
||||
const finalIndent =
|
||||
relativeLevel < 0
|
||||
? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
|
||||
: matchedIndent + currentIndent.slice(searchBaseLevel);
|
||||
: matchedIndent + currentIndent.slice(searchBaseLevel)
|
||||
|
||||
return finalIndent + line.trim();
|
||||
});
|
||||
return finalIndent + line.trim()
|
||||
})
|
||||
|
||||
// Construct the final content
|
||||
const beforeMatch = originalLines.slice(0, matchIndex);
|
||||
const afterMatch = originalLines.slice(matchIndex + searchLines.length);
|
||||
const beforeMatch = originalLines.slice(0, matchIndex)
|
||||
const afterMatch = originalLines.slice(matchIndex + searchLines.length)
|
||||
|
||||
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding);
|
||||
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
|
||||
return {
|
||||
success: true,
|
||||
content: finalContent
|
||||
};
|
||||
content: finalContent,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,21 +116,21 @@ Your diff here
|
||||
success: false,
|
||||
error: "Failed to apply unified diff - patch rejected",
|
||||
details: {
|
||||
searchContent: diffContent
|
||||
}
|
||||
searchContent: diffContent,
|
||||
},
|
||||
}
|
||||
}
|
||||
return {
|
||||
success: true,
|
||||
content: result
|
||||
content: result,
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Error applying unified diff: ${error.message}`,
|
||||
details: {
|
||||
searchContent: diffContent
|
||||
}
|
||||
searchContent: diffContent,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,13 +4,17 @@
|
||||
|
||||
export type DiffResult =
|
||||
| { success: true; content: string }
|
||||
| { success: false; error: string; details?: {
|
||||
similarity?: number;
|
||||
threshold?: number;
|
||||
matchedRange?: { start: number; end: number };
|
||||
searchContent?: string;
|
||||
bestMatch?: string;
|
||||
}};
|
||||
| {
|
||||
success: false
|
||||
error: string
|
||||
details?: {
|
||||
similarity?: number
|
||||
threshold?: number
|
||||
matchedRange?: { start: number; end: number }
|
||||
searchContent?: string
|
||||
bestMatch?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface DiffStrategy {
|
||||
/**
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
// Create mock vscode module before importing anything
|
||||
const createMockUri = (scheme: string, path: string) => ({
|
||||
scheme,
|
||||
authority: '',
|
||||
authority: "",
|
||||
path,
|
||||
query: '',
|
||||
fragment: '',
|
||||
query: "",
|
||||
fragment: "",
|
||||
fsPath: path,
|
||||
with: jest.fn(),
|
||||
toString: () => path,
|
||||
toJSON: () => ({
|
||||
scheme,
|
||||
authority: '',
|
||||
authority: "",
|
||||
path,
|
||||
query: '',
|
||||
fragment: ''
|
||||
})
|
||||
query: "",
|
||||
fragment: "",
|
||||
}),
|
||||
})
|
||||
|
||||
const mockExecuteCommand = jest.fn()
|
||||
@@ -23,9 +23,11 @@ const mockShowErrorMessage = jest.fn()
|
||||
|
||||
const mockVscode = {
|
||||
workspace: {
|
||||
workspaceFolders: [{
|
||||
uri: { fsPath: "/test/workspace" }
|
||||
}]
|
||||
workspaceFolders: [
|
||||
{
|
||||
uri: { fsPath: "/test/workspace" },
|
||||
},
|
||||
],
|
||||
},
|
||||
window: {
|
||||
showErrorMessage: mockShowErrorMessage,
|
||||
@@ -34,17 +36,17 @@ const mockVscode = {
|
||||
createTextEditorDecorationType: jest.fn(),
|
||||
createOutputChannel: jest.fn(),
|
||||
createWebviewPanel: jest.fn(),
|
||||
activeTextEditor: undefined
|
||||
activeTextEditor: undefined,
|
||||
},
|
||||
commands: {
|
||||
executeCommand: mockExecuteCommand
|
||||
executeCommand: mockExecuteCommand,
|
||||
},
|
||||
env: {
|
||||
openExternal: mockOpenExternal
|
||||
openExternal: mockOpenExternal,
|
||||
},
|
||||
Uri: {
|
||||
parse: jest.fn((url: string) => createMockUri('https', url)),
|
||||
file: jest.fn((path: string) => createMockUri('file', path))
|
||||
parse: jest.fn((url: string) => createMockUri("https", url)),
|
||||
file: jest.fn((path: string) => createMockUri("file", path)),
|
||||
},
|
||||
Position: jest.fn(),
|
||||
Range: jest.fn(),
|
||||
@@ -54,12 +56,12 @@ const mockVscode = {
|
||||
Error: 0,
|
||||
Warning: 1,
|
||||
Information: 2,
|
||||
Hint: 3
|
||||
}
|
||||
Hint: 3,
|
||||
},
|
||||
}
|
||||
|
||||
// Mock modules
|
||||
jest.mock('vscode', () => mockVscode)
|
||||
jest.mock("vscode", () => mockVscode)
|
||||
jest.mock("../../../services/browser/UrlContentFetcher")
|
||||
jest.mock("../../../utils/git")
|
||||
|
||||
@@ -97,11 +99,7 @@ Detailed commit message with multiple lines
|
||||
|
||||
jest.mocked(git.getCommitInfo).mockResolvedValue(commitInfo)
|
||||
|
||||
const result = await parseMentions(
|
||||
`Check out this commit @${commitHash}`,
|
||||
mockCwd,
|
||||
mockUrlContentFetcher
|
||||
)
|
||||
const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher)
|
||||
|
||||
expect(result).toContain(`'${commitHash}' (see below for commit info)`)
|
||||
expect(result).toContain(`<git_commit hash="${commitHash}">`)
|
||||
@@ -114,11 +112,7 @@ Detailed commit message with multiple lines
|
||||
|
||||
jest.mocked(git.getCommitInfo).mockRejectedValue(new Error(errorMessage))
|
||||
|
||||
const result = await parseMentions(
|
||||
`Check out this commit @${commitHash}`,
|
||||
mockCwd,
|
||||
mockUrlContentFetcher
|
||||
)
|
||||
const result = await parseMentions(`Check out this commit @${commitHash}`, mockCwd, mockUrlContentFetcher)
|
||||
|
||||
expect(result).toContain(`'${commitHash}' (see below for commit info)`)
|
||||
expect(result).toContain(`<git_commit hash="${commitHash}">`)
|
||||
@@ -143,13 +137,15 @@ Detailed commit message with multiple lines
|
||||
const mockUri = mockVscode.Uri.parse(url)
|
||||
expect(mockOpenExternal).toHaveBeenCalled()
|
||||
const calledArg = mockOpenExternal.mock.calls[0][0]
|
||||
expect(calledArg).toEqual(expect.objectContaining({
|
||||
expect(calledArg).toEqual(
|
||||
expect.objectContaining({
|
||||
scheme: mockUri.scheme,
|
||||
authority: mockUri.authority,
|
||||
path: mockUri.path,
|
||||
query: mockUri.query,
|
||||
fragment: mockUri.fragment
|
||||
}))
|
||||
fragment: mockUri.fragment,
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,12 +1,10 @@
|
||||
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from '../shared/modes';
|
||||
import { Mode, isToolAllowedForMode, TestToolName, getModeConfig } from "../shared/modes"
|
||||
|
||||
export { isToolAllowedForMode };
|
||||
export type { TestToolName };
|
||||
export { isToolAllowedForMode }
|
||||
export type { TestToolName }
|
||||
|
||||
export function validateToolUse(toolName: TestToolName, mode: Mode): void {
|
||||
if (!isToolAllowedForMode(toolName, mode)) {
|
||||
throw new Error(
|
||||
`Tool "${toolName}" is not allowed in ${mode} mode.`
|
||||
);
|
||||
throw new Error(`Tool "${toolName}" is not allowed in ${mode} mode.`)
|
||||
}
|
||||
}
|
||||
@@ -1,67 +1,68 @@
|
||||
import { SYSTEM_PROMPT, addCustomInstructions } from '../system'
|
||||
import { McpHub } from '../../../services/mcp/McpHub'
|
||||
import { McpServer } from '../../../shared/mcp'
|
||||
import { ClineProvider } from '../../../core/webview/ClineProvider'
|
||||
import { SearchReplaceDiffStrategy } from '../../../core/diff/strategies/search-replace'
|
||||
import fs from 'fs/promises'
|
||||
import os from 'os'
|
||||
import { defaultModeSlug, modes } from '../../../shared/modes'
|
||||
import { SYSTEM_PROMPT, addCustomInstructions } from "../system"
|
||||
import { McpHub } from "../../../services/mcp/McpHub"
|
||||
import { McpServer } from "../../../shared/mcp"
|
||||
import { ClineProvider } from "../../../core/webview/ClineProvider"
|
||||
import { SearchReplaceDiffStrategy } from "../../../core/diff/strategies/search-replace"
|
||||
import fs from "fs/promises"
|
||||
import os from "os"
|
||||
import { defaultModeSlug, modes } from "../../../shared/modes"
|
||||
// Import path utils to get access to toPosix string extension
|
||||
import '../../../utils/path'
|
||||
import "../../../utils/path"
|
||||
|
||||
// Mock environment-specific values for consistent tests
|
||||
jest.mock('os', () => ({
|
||||
...jest.requireActual('os'),
|
||||
homedir: () => '/home/user'
|
||||
jest.mock("os", () => ({
|
||||
...jest.requireActual("os"),
|
||||
homedir: () => "/home/user",
|
||||
}))
|
||||
|
||||
jest.mock('default-shell', () => '/bin/bash')
|
||||
jest.mock("default-shell", () => "/bin/bash")
|
||||
|
||||
jest.mock('os-name', () => () => 'Linux')
|
||||
jest.mock("os-name", () => () => "Linux")
|
||||
|
||||
// Mock fs.readFile to return empty mcpServers config and mock rules files
|
||||
jest.mock('fs/promises', () => ({
|
||||
...jest.requireActual('fs/promises'),
|
||||
jest.mock("fs/promises", () => ({
|
||||
...jest.requireActual("fs/promises"),
|
||||
readFile: jest.fn().mockImplementation(async (path: string) => {
|
||||
if (path.endsWith('mcpSettings.json')) {
|
||||
if (path.endsWith("mcpSettings.json")) {
|
||||
return '{"mcpServers": {}}'
|
||||
}
|
||||
if (path.endsWith('.clinerules-code')) {
|
||||
return '# Code Mode Rules\n1. Code specific rule'
|
||||
if (path.endsWith(".clinerules-code")) {
|
||||
return "# Code Mode Rules\n1. Code specific rule"
|
||||
}
|
||||
if (path.endsWith('.clinerules-ask')) {
|
||||
return '# Ask Mode Rules\n1. Ask specific rule'
|
||||
if (path.endsWith(".clinerules-ask")) {
|
||||
return "# Ask Mode Rules\n1. Ask specific rule"
|
||||
}
|
||||
if (path.endsWith('.clinerules-architect')) {
|
||||
return '# Architect Mode Rules\n1. Architect specific rule'
|
||||
if (path.endsWith(".clinerules-architect")) {
|
||||
return "# Architect Mode Rules\n1. Architect specific rule"
|
||||
}
|
||||
if (path.endsWith('.clinerules')) {
|
||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
||||
if (path.endsWith(".clinerules")) {
|
||||
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
}),
|
||||
writeFile: jest.fn().mockResolvedValue(undefined)
|
||||
writeFile: jest.fn().mockResolvedValue(undefined),
|
||||
}))
|
||||
|
||||
// Create a minimal mock of ClineProvider
|
||||
const mockProvider = {
|
||||
ensureMcpServersDirectoryExists: async () => '/mock/mcp/path',
|
||||
ensureSettingsDirectoryExists: async () => '/mock/settings/path',
|
||||
ensureMcpServersDirectoryExists: async () => "/mock/mcp/path",
|
||||
ensureSettingsDirectoryExists: async () => "/mock/settings/path",
|
||||
postMessageToWebview: async () => {},
|
||||
context: {
|
||||
extension: {
|
||||
packageJSON: {
|
||||
version: '1.0.0'
|
||||
}
|
||||
}
|
||||
}
|
||||
version: "1.0.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as ClineProvider
|
||||
|
||||
// Instead of extending McpHub, create a mock that implements just what we need
|
||||
const createMockMcpHub = (): McpHub => ({
|
||||
const createMockMcpHub = (): McpHub =>
|
||||
({
|
||||
getServers: () => [],
|
||||
getMcpServersPath: async () => '/mock/mcp/path',
|
||||
getMcpSettingsFilePath: async () => '/mock/settings/path',
|
||||
getMcpServersPath: async () => "/mock/mcp/path",
|
||||
getMcpSettingsFilePath: async () => "/mock/settings/path",
|
||||
dispose: async () => {},
|
||||
// Add other required public methods with no-op implementations
|
||||
restartConnection: async () => {},
|
||||
@@ -70,10 +71,10 @@ const createMockMcpHub = (): McpHub => ({
|
||||
toggleServerDisabled: async () => {},
|
||||
toggleToolAlwaysAllow: async () => {},
|
||||
isConnecting: false,
|
||||
connections: []
|
||||
} as unknown as McpHub)
|
||||
connections: [],
|
||||
}) as unknown as McpHub
|
||||
|
||||
describe('SYSTEM_PROMPT', () => {
|
||||
describe("SYSTEM_PROMPT", () => {
|
||||
let mockMcpHub: McpHub
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -87,73 +88,63 @@ describe('SYSTEM_PROMPT', () => {
|
||||
}
|
||||
})
|
||||
|
||||
it('should maintain consistent system prompt', async () => {
|
||||
it("should maintain consistent system prompt", async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
"/test/path",
|
||||
false, // supportsComputerUse
|
||||
undefined, // mcpHub
|
||||
undefined, // diffStrategy
|
||||
undefined // browserViewportSize
|
||||
undefined, // browserViewportSize
|
||||
)
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should include browser actions when supportsComputerUse is true', async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
true,
|
||||
undefined,
|
||||
undefined,
|
||||
'1280x800'
|
||||
)
|
||||
it("should include browser actions when supportsComputerUse is true", async () => {
|
||||
const prompt = await SYSTEM_PROMPT("/test/path", true, undefined, undefined, "1280x800")
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should include MCP server info when mcpHub is provided', async () => {
|
||||
it("should include MCP server info when mcpHub is provided", async () => {
|
||||
mockMcpHub = createMockMcpHub()
|
||||
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
false,
|
||||
mockMcpHub
|
||||
)
|
||||
const prompt = await SYSTEM_PROMPT("/test/path", false, mockMcpHub)
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should explicitly handle undefined mcpHub', async () => {
|
||||
it("should explicitly handle undefined mcpHub", async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
"/test/path",
|
||||
false,
|
||||
undefined, // explicitly undefined mcpHub
|
||||
undefined,
|
||||
undefined
|
||||
undefined,
|
||||
)
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should handle different browser viewport sizes', async () => {
|
||||
it("should handle different browser viewport sizes", async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
"/test/path",
|
||||
true,
|
||||
undefined,
|
||||
undefined,
|
||||
'900x600' // different viewport size
|
||||
"900x600", // different viewport size
|
||||
)
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should include diff strategy tool description', async () => {
|
||||
it("should include diff strategy tool description", async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
"/test/path",
|
||||
false,
|
||||
undefined,
|
||||
new SearchReplaceDiffStrategy(), // Use actual diff strategy from the codebase
|
||||
undefined
|
||||
undefined,
|
||||
)
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
@@ -164,253 +155,197 @@ describe('SYSTEM_PROMPT', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('addCustomInstructions', () => {
|
||||
describe("addCustomInstructions", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should generate correct prompt for architect mode', async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
'architect'
|
||||
)
|
||||
it("should generate correct prompt for architect mode", async () => {
|
||||
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "architect")
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should generate correct prompt for ask mode', async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
'ask'
|
||||
)
|
||||
it("should generate correct prompt for ask mode", async () => {
|
||||
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "ask")
|
||||
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should prioritize mode-specific rules for code mode', async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path',
|
||||
defaultModeSlug
|
||||
)
|
||||
it("should prioritize mode-specific rules for code mode", async () => {
|
||||
const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug)
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should prioritize mode-specific rules for ask mode', async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path',
|
||||
modes[2].slug
|
||||
)
|
||||
it("should prioritize mode-specific rules for ask mode", async () => {
|
||||
const instructions = await addCustomInstructions({}, "/test/path", modes[2].slug)
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should prioritize mode-specific rules for architect mode', async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path',
|
||||
modes[1].slug
|
||||
)
|
||||
it("should prioritize mode-specific rules for architect mode", async () => {
|
||||
const instructions = await addCustomInstructions({}, "/test/path", modes[1].slug)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should prioritize mode-specific rules for test engineer mode', async () => {
|
||||
it("should prioritize mode-specific rules for test engineer mode", async () => {
|
||||
// Mock readFile to include test engineer rules
|
||||
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
||||
if (path.endsWith('.clinerules-test')) {
|
||||
return '# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code'
|
||||
if (path.endsWith(".clinerules-test")) {
|
||||
return "# Test Engineer Rules\n1. Always write tests first\n2. Get approval before modifying non-test code"
|
||||
}
|
||||
if (path.endsWith('.clinerules')) {
|
||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
||||
if (path.endsWith(".clinerules")) {
|
||||
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
})
|
||||
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile)
|
||||
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
|
||||
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path',
|
||||
'test'
|
||||
)
|
||||
const instructions = await addCustomInstructions({}, "/test/path", "test")
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should prioritize mode-specific rules for code reviewer mode', async () => {
|
||||
it("should prioritize mode-specific rules for code reviewer mode", async () => {
|
||||
// Mock readFile to include code reviewer rules
|
||||
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
||||
if (path.endsWith('.clinerules-review')) {
|
||||
return '# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices'
|
||||
if (path.endsWith(".clinerules-review")) {
|
||||
return "# Code Reviewer Rules\n1. Provide specific examples in feedback\n2. Focus on maintainability and best practices"
|
||||
}
|
||||
if (path.endsWith('.clinerules')) {
|
||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
||||
if (path.endsWith(".clinerules")) {
|
||||
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
})
|
||||
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile)
|
||||
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
|
||||
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path',
|
||||
'review'
|
||||
)
|
||||
const instructions = await addCustomInstructions({}, "/test/path", "review")
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should generate correct prompt for test engineer mode', async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
'test'
|
||||
)
|
||||
it("should generate correct prompt for test engineer mode", async () => {
|
||||
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "test")
|
||||
|
||||
// Verify test engineer role requirements
|
||||
expect(prompt).toContain('must ask the user to confirm before making ANY changes to non-test code')
|
||||
expect(prompt).toContain('ask the user to confirm your test plan')
|
||||
expect(prompt).toContain("must ask the user to confirm before making ANY changes to non-test code")
|
||||
expect(prompt).toContain("ask the user to confirm your test plan")
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should generate correct prompt for code reviewer mode', async () => {
|
||||
const prompt = await SYSTEM_PROMPT(
|
||||
'/test/path',
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
'review'
|
||||
)
|
||||
it("should generate correct prompt for code reviewer mode", async () => {
|
||||
const prompt = await SYSTEM_PROMPT("/test/path", false, undefined, undefined, undefined, "review")
|
||||
|
||||
// Verify code reviewer role constraints
|
||||
expect(prompt).toContain('providing detailed, actionable feedback')
|
||||
expect(prompt).toContain('maintain a read-only approach')
|
||||
expect(prompt).toContain("providing detailed, actionable feedback")
|
||||
expect(prompt).toContain("maintain a read-only approach")
|
||||
expect(prompt).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should fall back to generic rules when mode-specific rules not found', async () => {
|
||||
it("should fall back to generic rules when mode-specific rules not found", async () => {
|
||||
// Mock readFile to return ENOENT for mode-specific file
|
||||
const mockReadFile = jest.fn().mockImplementation(async (path: string) => {
|
||||
if (path.endsWith('.clinerules-code') ||
|
||||
path.endsWith('.clinerules-test') ||
|
||||
path.endsWith('.clinerules-review')) {
|
||||
const error = new Error('ENOENT') as NodeJS.ErrnoException
|
||||
error.code = 'ENOENT'
|
||||
if (
|
||||
path.endsWith(".clinerules-code") ||
|
||||
path.endsWith(".clinerules-test") ||
|
||||
path.endsWith(".clinerules-review")
|
||||
) {
|
||||
const error = new Error("ENOENT") as NodeJS.ErrnoException
|
||||
error.code = "ENOENT"
|
||||
throw error
|
||||
}
|
||||
if (path.endsWith('.clinerules')) {
|
||||
return '# Test Rules\n1. First rule\n2. Second rule'
|
||||
if (path.endsWith(".clinerules")) {
|
||||
return "# Test Rules\n1. First rule\n2. Second rule"
|
||||
}
|
||||
return ''
|
||||
return ""
|
||||
})
|
||||
jest.spyOn(fs, 'readFile').mockImplementation(mockReadFile)
|
||||
jest.spyOn(fs, "readFile").mockImplementation(mockReadFile)
|
||||
|
||||
const instructions = await addCustomInstructions({}, "/test/path", defaultModeSlug)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it("should include preferred language when provided", async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path',
|
||||
defaultModeSlug
|
||||
{ preferredLanguage: "Spanish" },
|
||||
"/test/path",
|
||||
defaultModeSlug,
|
||||
)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should include preferred language when provided', async () => {
|
||||
it("should include custom instructions when provided", async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{ preferredLanguage: 'Spanish' },
|
||||
'/test/path',
|
||||
defaultModeSlug
|
||||
{ customInstructions: "Custom test instructions" },
|
||||
"/test/path",
|
||||
)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should include custom instructions when provided', async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{ customInstructions: 'Custom test instructions' },
|
||||
'/test/path'
|
||||
)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should combine all custom instructions', async () => {
|
||||
it("should combine all custom instructions", async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{
|
||||
customInstructions: 'Custom test instructions',
|
||||
preferredLanguage: 'French'
|
||||
customInstructions: "Custom test instructions",
|
||||
preferredLanguage: "French",
|
||||
},
|
||||
'/test/path',
|
||||
defaultModeSlug
|
||||
"/test/path",
|
||||
defaultModeSlug,
|
||||
)
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should handle undefined mode-specific instructions', async () => {
|
||||
it("should handle undefined mode-specific instructions", async () => {
|
||||
const instructions = await addCustomInstructions({}, "/test/path")
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it("should trim mode-specific instructions", async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{},
|
||||
'/test/path'
|
||||
{ customInstructions: " Custom mode instructions " },
|
||||
"/test/path",
|
||||
)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should trim mode-specific instructions', async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{ customInstructions: ' Custom mode instructions ' },
|
||||
'/test/path'
|
||||
)
|
||||
it("should handle empty mode-specific instructions", async () => {
|
||||
const instructions = await addCustomInstructions({ customInstructions: "" }, "/test/path")
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should handle empty mode-specific instructions', async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{ customInstructions: '' },
|
||||
'/test/path'
|
||||
)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should combine global and mode-specific instructions', async () => {
|
||||
it("should combine global and mode-specific instructions", async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{
|
||||
customInstructions: 'Global instructions',
|
||||
customInstructions: "Global instructions",
|
||||
customPrompts: {
|
||||
code: { customInstructions: 'Mode-specific instructions' }
|
||||
}
|
||||
code: { customInstructions: "Mode-specific instructions" },
|
||||
},
|
||||
'/test/path',
|
||||
defaultModeSlug
|
||||
},
|
||||
"/test/path",
|
||||
defaultModeSlug,
|
||||
)
|
||||
|
||||
expect(instructions).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should prioritize mode-specific instructions after global ones', async () => {
|
||||
it("should prioritize mode-specific instructions after global ones", async () => {
|
||||
const instructions = await addCustomInstructions(
|
||||
{
|
||||
customInstructions: 'First instruction',
|
||||
customInstructions: "First instruction",
|
||||
customPrompts: {
|
||||
code: { customInstructions: 'Second instruction' }
|
||||
}
|
||||
code: { customInstructions: "Second instruction" },
|
||||
},
|
||||
'/test/path',
|
||||
defaultModeSlug
|
||||
},
|
||||
"/test/path",
|
||||
defaultModeSlug,
|
||||
)
|
||||
|
||||
const instructionParts = instructions.split('\n\n')
|
||||
const globalIndex = instructionParts.findIndex(part => part.includes('First instruction'))
|
||||
const modeSpecificIndex = instructionParts.findIndex(part => part.includes('Second instruction'))
|
||||
const instructionParts = instructions.split("\n\n")
|
||||
const globalIndex = instructionParts.findIndex((part) => part.includes("First instruction"))
|
||||
const modeSpecificIndex = instructionParts.findIndex((part) => part.includes("Second instruction"))
|
||||
|
||||
expect(globalIndex).toBeLessThan(modeSpecificIndex)
|
||||
expect(instructions).toMatchSnapshot()
|
||||
|
||||
@@ -22,7 +22,11 @@ CAPABILITIES
|
||||
supportsComputerUse
|
||||
? "\n- You can use the browser_action tool to interact with websites (including html files and locally running development servers) through a Puppeteer-controlled browser when you feel it is necessary in accomplishing the user's task. This tool is particularly useful for web development tasks as it allows you to launch a browser, navigate to pages, interact with elements through clicks and keyboard input, and capture the results through screenshots and console logs. This tool may be useful at key stages of web development tasks-such as after implementing new features, making substantial changes, when troubleshooting issues, or to verify the result of your work. You can analyze the provided screenshots to ensure correct rendering or identify errors, and review console logs for runtime issues.\n - For example, if asked to add a component to a react website, you might create the necessary files, use execute_command to run the site locally, then use browser_action to launch the browser, navigate to the local server, and verify the component renders & functions correctly before closing the browser."
|
||||
: ""
|
||||
}${mcpHub ? `
|
||||
}${
|
||||
mcpHub
|
||||
? `
|
||||
- You have access to MCP servers that may provide additional tools and resources. Each server may provide different capabilities that you can use to accomplish tasks more effectively.
|
||||
` : ''}`
|
||||
`
|
||||
: ""
|
||||
}`
|
||||
}
|
||||
@@ -1,19 +1,19 @@
|
||||
import fs from 'fs/promises'
|
||||
import path from 'path'
|
||||
import fs from "fs/promises"
|
||||
import path from "path"
|
||||
|
||||
export async function loadRuleFiles(cwd: string): Promise<string> {
|
||||
const ruleFiles = ['.clinerules', '.cursorrules', '.windsurfrules']
|
||||
let combinedRules = ''
|
||||
const ruleFiles = [".clinerules", ".cursorrules", ".windsurfrules"]
|
||||
let combinedRules = ""
|
||||
|
||||
for (const file of ruleFiles) {
|
||||
try {
|
||||
const content = await fs.readFile(path.join(cwd, file), 'utf-8')
|
||||
const content = await fs.readFile(path.join(cwd, file), "utf-8")
|
||||
if (content.trim()) {
|
||||
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
|
||||
}
|
||||
} catch (err) {
|
||||
// Silently skip if file doesn't exist
|
||||
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') {
|
||||
if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,11 @@ export async function loadRuleFiles(cwd: string): Promise<string> {
|
||||
return combinedRules
|
||||
}
|
||||
|
||||
export async function addCustomInstructions(customInstructions: string, cwd: string, preferredLanguage?: string): Promise<string> {
|
||||
export async function addCustomInstructions(
|
||||
customInstructions: string,
|
||||
cwd: string,
|
||||
preferredLanguage?: string,
|
||||
): Promise<string> {
|
||||
const ruleFileContent = await loadRuleFiles(cwd)
|
||||
const allInstructions = []
|
||||
|
||||
@@ -38,9 +42,10 @@ export async function addCustomInstructions(customInstructions: string, cwd: str
|
||||
allInstructions.push(ruleFileContent.trim())
|
||||
}
|
||||
|
||||
const joinedInstructions = allInstructions.join('\n\n')
|
||||
const joinedInstructions = allInstructions.join("\n\n")
|
||||
|
||||
return joinedInstructions ? `
|
||||
return joinedInstructions
|
||||
? `
|
||||
====
|
||||
|
||||
USER'S CUSTOM INSTRUCTIONS
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
export { getRulesSection } from './rules'
|
||||
export { getSystemInfoSection } from './system-info'
|
||||
export { getObjectiveSection } from './objective'
|
||||
export { addCustomInstructions } from './custom-instructions'
|
||||
export { getSharedToolUseSection } from './tool-use'
|
||||
export { getMcpServersSection } from './mcp-servers'
|
||||
export { getToolUseGuidelinesSection } from './tool-use-guidelines'
|
||||
export { getCapabilitiesSection } from './capabilities'
|
||||
export { getRulesSection } from "./rules"
|
||||
export { getSystemInfoSection } from "./system-info"
|
||||
export { getObjectiveSection } from "./objective"
|
||||
export { addCustomInstructions } from "./custom-instructions"
|
||||
export { getSharedToolUseSection } from "./tool-use"
|
||||
export { getMcpServersSection } from "./mcp-servers"
|
||||
export { getToolUseGuidelinesSection } from "./tool-use-guidelines"
|
||||
export { getCapabilitiesSection } from "./capabilities"
|
||||
|
||||
@@ -3,10 +3,11 @@ import { McpHub } from "../../../services/mcp/McpHub"
|
||||
|
||||
export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffStrategy): Promise<string> {
|
||||
if (!mcpHub) {
|
||||
return '';
|
||||
return ""
|
||||
}
|
||||
|
||||
const connectedServers = mcpHub.getServers().length > 0
|
||||
const connectedServers =
|
||||
mcpHub.getServers().length > 0
|
||||
? `${mcpHub
|
||||
.getServers()
|
||||
.filter((server) => server.status === "connected")
|
||||
@@ -40,7 +41,7 @@ export async function getMcpServersSection(mcpHub?: McpHub, diffStrategy?: DiffS
|
||||
)
|
||||
})
|
||||
.join("\n\n")}`
|
||||
: "(No MCP servers currently connected)";
|
||||
: "(No MCP servers currently connected)"
|
||||
|
||||
return `MCP SERVERS
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import { DiffStrategy } from "../../diff/DiffStrategy"
|
||||
|
||||
export function getRulesSection(
|
||||
cwd: string,
|
||||
supportsComputerUse: boolean,
|
||||
diffStrategy?: DiffStrategy
|
||||
): string {
|
||||
export function getRulesSection(cwd: string, supportsComputerUse: boolean, diffStrategy?: DiffStrategy): string {
|
||||
return `====
|
||||
|
||||
RULES
|
||||
|
||||
@@ -9,39 +9,39 @@ import {
|
||||
getSharedToolUseSection,
|
||||
getMcpServersSection,
|
||||
getToolUseGuidelinesSection,
|
||||
getCapabilitiesSection
|
||||
getCapabilitiesSection,
|
||||
} from "./sections"
|
||||
import fs from 'fs/promises'
|
||||
import path from 'path'
|
||||
import fs from "fs/promises"
|
||||
import path from "path"
|
||||
|
||||
async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> {
|
||||
let combinedRules = ''
|
||||
let combinedRules = ""
|
||||
|
||||
// First try mode-specific rules
|
||||
const modeSpecificFile = `.clinerules-${mode}`
|
||||
try {
|
||||
const content = await fs.readFile(path.join(cwd, modeSpecificFile), 'utf-8')
|
||||
const content = await fs.readFile(path.join(cwd, modeSpecificFile), "utf-8")
|
||||
if (content.trim()) {
|
||||
combinedRules += `\n# Rules from ${modeSpecificFile}:\n${content.trim()}\n`
|
||||
}
|
||||
} catch (err) {
|
||||
// Silently skip if file doesn't exist
|
||||
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') {
|
||||
if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
// Then try generic rules files
|
||||
const genericRuleFiles = ['.clinerules']
|
||||
const genericRuleFiles = [".clinerules"]
|
||||
for (const file of genericRuleFiles) {
|
||||
try {
|
||||
const content = await fs.readFile(path.join(cwd, file), 'utf-8')
|
||||
const content = await fs.readFile(path.join(cwd, file), "utf-8")
|
||||
if (content.trim()) {
|
||||
combinedRules += `\n# Rules from ${file}:\n${content.trim()}\n`
|
||||
}
|
||||
} catch (err) {
|
||||
// Silently skip if file doesn't exist
|
||||
if ((err as NodeJS.ErrnoException).code !== 'ENOENT') {
|
||||
if ((err as NodeJS.ErrnoException).code !== "ENOENT") {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
@@ -51,16 +51,12 @@ async function loadRuleFiles(cwd: string, mode: Mode): Promise<string> {
|
||||
}
|
||||
|
||||
interface State {
|
||||
customInstructions?: string;
|
||||
customPrompts?: CustomPrompts;
|
||||
preferredLanguage?: string;
|
||||
customInstructions?: string
|
||||
customPrompts?: CustomPrompts
|
||||
preferredLanguage?: string
|
||||
}
|
||||
|
||||
export async function addCustomInstructions(
|
||||
state: State,
|
||||
cwd: string,
|
||||
mode: Mode = defaultModeSlug
|
||||
): Promise<string> {
|
||||
export async function addCustomInstructions(state: State, cwd: string, mode: Mode = defaultModeSlug): Promise<string> {
|
||||
const ruleFileContent = await loadRuleFiles(cwd, mode)
|
||||
const allInstructions = []
|
||||
|
||||
@@ -73,7 +69,7 @@ export async function addCustomInstructions(
|
||||
}
|
||||
|
||||
const customPrompt = state.customPrompts?.[mode]
|
||||
if (typeof customPrompt === 'object' && customPrompt?.customInstructions?.trim()) {
|
||||
if (typeof customPrompt === "object" && customPrompt?.customInstructions?.trim()) {
|
||||
allInstructions.push(customPrompt.customInstructions.trim())
|
||||
}
|
||||
|
||||
@@ -81,9 +77,10 @@ export async function addCustomInstructions(
|
||||
allInstructions.push(ruleFileContent.trim())
|
||||
}
|
||||
|
||||
const joinedInstructions = allInstructions.join('\n\n')
|
||||
const joinedInstructions = allInstructions.join("\n\n")
|
||||
|
||||
return joinedInstructions ? `
|
||||
return joinedInstructions
|
||||
? `
|
||||
====
|
||||
|
||||
USER'S CUSTOM INSTRUCTIONS
|
||||
@@ -119,9 +116,9 @@ ${getRulesSection(cwd, supportsComputerUse, diffStrategy)}
|
||||
|
||||
${getSystemInfoSection(cwd)}
|
||||
|
||||
${getObjectiveSection()}`;
|
||||
${getObjectiveSection()}`
|
||||
|
||||
return basePrompt;
|
||||
return basePrompt
|
||||
}
|
||||
|
||||
export const SYSTEM_PROMPT = async (
|
||||
@@ -134,15 +131,15 @@ export const SYSTEM_PROMPT = async (
|
||||
customPrompts?: CustomPrompts,
|
||||
) => {
|
||||
const getPromptComponent = (value: unknown) => {
|
||||
if (typeof value === 'object' && value !== null) {
|
||||
return value as PromptComponent;
|
||||
if (typeof value === "object" && value !== null) {
|
||||
return value as PromptComponent
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
// Use default mode if not found
|
||||
const currentMode = modes.find(m => m.slug === mode) || modes[0];
|
||||
const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug]);
|
||||
const currentMode = modes.find((m) => m.slug === mode) || modes[0]
|
||||
const promptComponent = getPromptComponent(customPrompts?.[currentMode.slug])
|
||||
|
||||
return generatePrompt(
|
||||
cwd,
|
||||
@@ -151,6 +148,6 @@ export const SYSTEM_PROMPT = async (
|
||||
mcpHub,
|
||||
diffStrategy,
|
||||
browserViewportSize,
|
||||
promptComponent
|
||||
);
|
||||
promptComponent,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getAccessMcpResourceDescription(args: ToolArgs): string | undefined {
|
||||
if (!args.mcpHub) {
|
||||
return undefined;
|
||||
return undefined
|
||||
}
|
||||
return `## access_mcp_resource
|
||||
Description: Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getBrowserActionDescription(args: ToolArgs): string | undefined {
|
||||
if (!args.supportsComputerUse) {
|
||||
return undefined;
|
||||
return undefined
|
||||
}
|
||||
return `## browser_action
|
||||
Description: Request to interact with a Puppeteer-controlled browser. Every action, except \`close\`, will be responded to with a screenshot of the browser's current state, along with any new console logs. You may only perform one browser action per message, and wait for the user's response including a screenshot and logs to determine the next action.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getExecuteCommandDescription(args: ToolArgs): string | undefined {
|
||||
return `## execute_command
|
||||
|
||||
@@ -1,34 +1,35 @@
|
||||
import { getExecuteCommandDescription } from './execute-command'
|
||||
import { getReadFileDescription } from './read-file'
|
||||
import { getWriteToFileDescription } from './write-to-file'
|
||||
import { getSearchFilesDescription } from './search-files'
|
||||
import { getListFilesDescription } from './list-files'
|
||||
import { getListCodeDefinitionNamesDescription } from './list-code-definition-names'
|
||||
import { getBrowserActionDescription } from './browser-action'
|
||||
import { getAskFollowupQuestionDescription } from './ask-followup-question'
|
||||
import { getAttemptCompletionDescription } from './attempt-completion'
|
||||
import { getUseMcpToolDescription } from './use-mcp-tool'
|
||||
import { getAccessMcpResourceDescription } from './access-mcp-resource'
|
||||
import { DiffStrategy } from '../../diff/DiffStrategy'
|
||||
import { McpHub } from '../../../services/mcp/McpHub'
|
||||
import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from '../../../shared/modes'
|
||||
import { ToolArgs } from './types'
|
||||
import { getExecuteCommandDescription } from "./execute-command"
|
||||
import { getReadFileDescription } from "./read-file"
|
||||
import { getWriteToFileDescription } from "./write-to-file"
|
||||
import { getSearchFilesDescription } from "./search-files"
|
||||
import { getListFilesDescription } from "./list-files"
|
||||
import { getListCodeDefinitionNamesDescription } from "./list-code-definition-names"
|
||||
import { getBrowserActionDescription } from "./browser-action"
|
||||
import { getAskFollowupQuestionDescription } from "./ask-followup-question"
|
||||
import { getAttemptCompletionDescription } from "./attempt-completion"
|
||||
import { getUseMcpToolDescription } from "./use-mcp-tool"
|
||||
import { getAccessMcpResourceDescription } from "./access-mcp-resource"
|
||||
import { DiffStrategy } from "../../diff/DiffStrategy"
|
||||
import { McpHub } from "../../../services/mcp/McpHub"
|
||||
import { Mode, ToolName, getModeConfig, isToolAllowedForMode } from "../../../shared/modes"
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
// Map of tool names to their description functions
|
||||
const toolDescriptionMap: Record<string, (args: ToolArgs) => string | undefined> = {
|
||||
'execute_command': args => getExecuteCommandDescription(args),
|
||||
'read_file': args => getReadFileDescription(args),
|
||||
'write_to_file': args => getWriteToFileDescription(args),
|
||||
'search_files': args => getSearchFilesDescription(args),
|
||||
'list_files': args => getListFilesDescription(args),
|
||||
'list_code_definition_names': args => getListCodeDefinitionNamesDescription(args),
|
||||
'browser_action': args => getBrowserActionDescription(args),
|
||||
'ask_followup_question': () => getAskFollowupQuestionDescription(),
|
||||
'attempt_completion': () => getAttemptCompletionDescription(),
|
||||
'use_mcp_tool': args => getUseMcpToolDescription(args),
|
||||
'access_mcp_resource': args => getAccessMcpResourceDescription(args),
|
||||
'apply_diff': args => args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : ''
|
||||
};
|
||||
execute_command: (args) => getExecuteCommandDescription(args),
|
||||
read_file: (args) => getReadFileDescription(args),
|
||||
write_to_file: (args) => getWriteToFileDescription(args),
|
||||
search_files: (args) => getSearchFilesDescription(args),
|
||||
list_files: (args) => getListFilesDescription(args),
|
||||
list_code_definition_names: (args) => getListCodeDefinitionNamesDescription(args),
|
||||
browser_action: (args) => getBrowserActionDescription(args),
|
||||
ask_followup_question: () => getAskFollowupQuestionDescription(),
|
||||
attempt_completion: () => getAttemptCompletionDescription(),
|
||||
use_mcp_tool: (args) => getUseMcpToolDescription(args),
|
||||
access_mcp_resource: (args) => getAccessMcpResourceDescription(args),
|
||||
apply_diff: (args) =>
|
||||
args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : "",
|
||||
}
|
||||
|
||||
export function getToolDescriptionsForMode(
|
||||
mode: Mode,
|
||||
@@ -36,31 +37,31 @@ export function getToolDescriptionsForMode(
|
||||
supportsComputerUse: boolean,
|
||||
diffStrategy?: DiffStrategy,
|
||||
browserViewportSize?: string,
|
||||
mcpHub?: McpHub
|
||||
mcpHub?: McpHub,
|
||||
): string {
|
||||
const config = getModeConfig(mode);
|
||||
const config = getModeConfig(mode)
|
||||
const args: ToolArgs = {
|
||||
cwd,
|
||||
supportsComputerUse,
|
||||
diffStrategy,
|
||||
browserViewportSize,
|
||||
mcpHub
|
||||
};
|
||||
mcpHub,
|
||||
}
|
||||
|
||||
// Map tool descriptions in the exact order specified in the mode's tools array
|
||||
const descriptions = config.tools.map(([toolName, toolOptions]) => {
|
||||
const descriptionFn = toolDescriptionMap[toolName];
|
||||
const descriptionFn = toolDescriptionMap[toolName]
|
||||
if (!descriptionFn || !isToolAllowedForMode(toolName as ToolName, mode)) {
|
||||
return undefined;
|
||||
return undefined
|
||||
}
|
||||
|
||||
return descriptionFn({
|
||||
...args,
|
||||
toolOptions
|
||||
});
|
||||
});
|
||||
toolOptions,
|
||||
})
|
||||
})
|
||||
|
||||
return `# Tools\n\n${descriptions.filter(Boolean).join('\n\n')}`;
|
||||
return `# Tools\n\n${descriptions.filter(Boolean).join("\n\n")}`
|
||||
}
|
||||
|
||||
// Export individual description functions for backward compatibility
|
||||
@@ -75,5 +76,5 @@ export {
|
||||
getAskFollowupQuestionDescription,
|
||||
getAttemptCompletionDescription,
|
||||
getUseMcpToolDescription,
|
||||
getAccessMcpResourceDescription
|
||||
getAccessMcpResourceDescription,
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getListCodeDefinitionNamesDescription(args: ToolArgs): string {
|
||||
return `## list_code_definition_names
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getListFilesDescription(args: ToolArgs): string {
|
||||
return `## list_files
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getReadFileDescription(args: ToolArgs): string {
|
||||
return `## read_file
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getSearchFilesDescription(args: ToolArgs): string {
|
||||
return `## search_files
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { DiffStrategy } from '../../diff/DiffStrategy'
|
||||
import { McpHub } from '../../../services/mcp/McpHub'
|
||||
import { DiffStrategy } from "../../diff/DiffStrategy"
|
||||
import { McpHub } from "../../../services/mcp/McpHub"
|
||||
|
||||
export type ToolArgs = {
|
||||
cwd: string;
|
||||
supportsComputerUse: boolean;
|
||||
diffStrategy?: DiffStrategy;
|
||||
browserViewportSize?: string;
|
||||
mcpHub?: McpHub;
|
||||
toolOptions?: any;
|
||||
};
|
||||
cwd: string
|
||||
supportsComputerUse: boolean
|
||||
diffStrategy?: DiffStrategy
|
||||
browserViewportSize?: string
|
||||
mcpHub?: McpHub
|
||||
toolOptions?: any
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getUseMcpToolDescription(args: ToolArgs): string | undefined {
|
||||
if (!args.mcpHub) {
|
||||
return undefined;
|
||||
return undefined
|
||||
}
|
||||
return `## use_mcp_tool
|
||||
Description: Request to use a tool provided by a connected MCP server. Each MCP server can provide multiple tools with different capabilities. Tools have defined input schemas that specify required and optional parameters.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ToolArgs } from './types';
|
||||
import { ToolArgs } from "./types"
|
||||
|
||||
export function getWriteToFileDescription(args: ToolArgs): string {
|
||||
return `## write_to_file
|
||||
|
||||
@@ -1,52 +1,52 @@
|
||||
import { Mode } from '../../shared/modes';
|
||||
import { Mode } from "../../shared/modes"
|
||||
|
||||
export type { Mode };
|
||||
export type { Mode }
|
||||
|
||||
export type ToolName =
|
||||
| 'execute_command'
|
||||
| 'read_file'
|
||||
| 'write_to_file'
|
||||
| 'apply_diff'
|
||||
| 'search_files'
|
||||
| 'list_files'
|
||||
| 'list_code_definition_names'
|
||||
| 'browser_action'
|
||||
| 'use_mcp_tool'
|
||||
| 'access_mcp_resource'
|
||||
| 'ask_followup_question'
|
||||
| 'attempt_completion';
|
||||
| "execute_command"
|
||||
| "read_file"
|
||||
| "write_to_file"
|
||||
| "apply_diff"
|
||||
| "search_files"
|
||||
| "list_files"
|
||||
| "list_code_definition_names"
|
||||
| "browser_action"
|
||||
| "use_mcp_tool"
|
||||
| "access_mcp_resource"
|
||||
| "ask_followup_question"
|
||||
| "attempt_completion"
|
||||
|
||||
export const CODE_TOOLS: ToolName[] = [
|
||||
'execute_command',
|
||||
'read_file',
|
||||
'write_to_file',
|
||||
'apply_diff',
|
||||
'search_files',
|
||||
'list_files',
|
||||
'list_code_definition_names',
|
||||
'browser_action',
|
||||
'use_mcp_tool',
|
||||
'access_mcp_resource',
|
||||
'ask_followup_question',
|
||||
'attempt_completion'
|
||||
];
|
||||
"execute_command",
|
||||
"read_file",
|
||||
"write_to_file",
|
||||
"apply_diff",
|
||||
"search_files",
|
||||
"list_files",
|
||||
"list_code_definition_names",
|
||||
"browser_action",
|
||||
"use_mcp_tool",
|
||||
"access_mcp_resource",
|
||||
"ask_followup_question",
|
||||
"attempt_completion",
|
||||
]
|
||||
|
||||
export const ARCHITECT_TOOLS: ToolName[] = [
|
||||
'read_file',
|
||||
'search_files',
|
||||
'list_files',
|
||||
'list_code_definition_names',
|
||||
'ask_followup_question',
|
||||
'attempt_completion'
|
||||
];
|
||||
"read_file",
|
||||
"search_files",
|
||||
"list_files",
|
||||
"list_code_definition_names",
|
||||
"ask_followup_question",
|
||||
"attempt_completion",
|
||||
]
|
||||
|
||||
export const ASK_TOOLS: ToolName[] = [
|
||||
'read_file',
|
||||
'search_files',
|
||||
'list_files',
|
||||
'browser_action',
|
||||
'use_mcp_tool',
|
||||
'access_mcp_resource',
|
||||
'ask_followup_question',
|
||||
'attempt_completion'
|
||||
];
|
||||
"read_file",
|
||||
"search_files",
|
||||
"list_files",
|
||||
"browser_action",
|
||||
"use_mcp_tool",
|
||||
"access_mcp_resource",
|
||||
"ask_followup_question",
|
||||
"attempt_completion",
|
||||
]
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
// Shared tools for architect and ask modes - read-only operations plus MCP and browser tools
|
||||
export const READONLY_ALLOWED_TOOLS = [
|
||||
'read_file',
|
||||
'search_files',
|
||||
'list_files',
|
||||
'list_code_definition_names',
|
||||
'browser_action',
|
||||
'use_mcp_tool',
|
||||
'access_mcp_resource',
|
||||
'ask_followup_question',
|
||||
'attempt_completion'
|
||||
] as const;
|
||||
"read_file",
|
||||
"search_files",
|
||||
"list_files",
|
||||
"list_code_definition_names",
|
||||
"browser_action",
|
||||
"use_mcp_tool",
|
||||
"access_mcp_resource",
|
||||
"ask_followup_question",
|
||||
"attempt_completion",
|
||||
] as const
|
||||
|
||||
// Code mode has access to all tools
|
||||
export const CODE_ALLOWED_TOOLS = [
|
||||
'execute_command',
|
||||
'read_file',
|
||||
'write_to_file',
|
||||
'apply_diff',
|
||||
'search_files',
|
||||
'list_files',
|
||||
'list_code_definition_names',
|
||||
'browser_action',
|
||||
'use_mcp_tool',
|
||||
'access_mcp_resource',
|
||||
'ask_followup_question',
|
||||
'attempt_completion'
|
||||
] as const;
|
||||
"execute_command",
|
||||
"read_file",
|
||||
"write_to_file",
|
||||
"apply_diff",
|
||||
"search_files",
|
||||
"list_files",
|
||||
"list_code_definition_names",
|
||||
"browser_action",
|
||||
"use_mcp_tool",
|
||||
"access_mcp_resource",
|
||||
"ask_followup_question",
|
||||
"attempt_completion",
|
||||
] as const
|
||||
|
||||
// Tool name types for type safety
|
||||
export type ReadOnlyToolName = typeof READONLY_ALLOWED_TOOLS[number];
|
||||
export type ToolName = typeof CODE_ALLOWED_TOOLS[number];
|
||||
export type ReadOnlyToolName = (typeof READONLY_ALLOWED_TOOLS)[number]
|
||||
export type ToolName = (typeof CODE_ALLOWED_TOOLS)[number]
|
||||
|
||||
@@ -254,14 +254,12 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
fuzzyMatchThreshold,
|
||||
mode,
|
||||
customInstructions: globalInstructions,
|
||||
experimentalDiffStrategy
|
||||
experimentalDiffStrategy,
|
||||
} = await this.getState()
|
||||
|
||||
const modePrompt = customPrompts?.[mode]
|
||||
const modeInstructions = typeof modePrompt === 'object' ? modePrompt.customInstructions : undefined
|
||||
const effectiveInstructions = [globalInstructions, modeInstructions]
|
||||
.filter(Boolean)
|
||||
.join('\n\n')
|
||||
const modeInstructions = typeof modePrompt === "object" ? modePrompt.customInstructions : undefined
|
||||
const effectiveInstructions = [globalInstructions, modeInstructions].filter(Boolean).join("\n\n")
|
||||
|
||||
this.cline = new Cline(
|
||||
this,
|
||||
@@ -272,7 +270,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
task,
|
||||
images,
|
||||
undefined,
|
||||
experimentalDiffStrategy
|
||||
experimentalDiffStrategy,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -285,14 +283,12 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
fuzzyMatchThreshold,
|
||||
mode,
|
||||
customInstructions: globalInstructions,
|
||||
experimentalDiffStrategy
|
||||
experimentalDiffStrategy,
|
||||
} = await this.getState()
|
||||
|
||||
const modePrompt = customPrompts?.[mode]
|
||||
const modeInstructions = typeof modePrompt === 'object' ? modePrompt.customInstructions : undefined
|
||||
const effectiveInstructions = [globalInstructions, modeInstructions]
|
||||
.filter(Boolean)
|
||||
.join('\n\n')
|
||||
const modeInstructions = typeof modePrompt === "object" ? modePrompt.customInstructions : undefined
|
||||
const effectiveInstructions = [globalInstructions, modeInstructions].filter(Boolean).join("\n\n")
|
||||
|
||||
this.cline = new Cline(
|
||||
this,
|
||||
@@ -303,7 +299,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
undefined,
|
||||
undefined,
|
||||
historyItem,
|
||||
experimentalDiffStrategy
|
||||
experimentalDiffStrategy,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -403,7 +399,6 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
async (message: WebviewMessage) => {
|
||||
switch (message.type) {
|
||||
case "webviewDidLaunch":
|
||||
|
||||
this.postStateToWebview()
|
||||
this.workspaceTracker?.initializeFilePaths() // don't await
|
||||
getTheme().then((theme) =>
|
||||
@@ -450,9 +445,9 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
this.configManager.ListConfig().then(async (listApiConfig) => {
|
||||
|
||||
this.configManager
|
||||
.ListConfig()
|
||||
.then(async (listApiConfig) => {
|
||||
if (!listApiConfig) {
|
||||
return
|
||||
}
|
||||
@@ -460,22 +455,25 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
if (listApiConfig.length === 1) {
|
||||
// check if first time init then sync with exist config
|
||||
if (!checkExistKey(listApiConfig[0])) {
|
||||
const {
|
||||
const { apiConfiguration } = await this.getState()
|
||||
await this.configManager.SaveConfig(
|
||||
listApiConfig[0].name ?? "default",
|
||||
apiConfiguration,
|
||||
} = await this.getState()
|
||||
await this.configManager.SaveConfig(listApiConfig[0].name ?? "default", apiConfiguration)
|
||||
)
|
||||
listApiConfig[0].apiProvider = apiConfiguration.apiProvider
|
||||
}
|
||||
}
|
||||
|
||||
let currentConfigName = await this.getGlobalState("currentApiConfigName") as string
|
||||
let currentConfigName = (await this.getGlobalState("currentApiConfigName")) as string
|
||||
|
||||
if (currentConfigName) {
|
||||
if (!await this.configManager.HasConfig(currentConfigName)) {
|
||||
if (!(await this.configManager.HasConfig(currentConfigName))) {
|
||||
// current config name not valid, get first config in list
|
||||
await this.updateGlobalState("currentApiConfigName", listApiConfig?.[0]?.name)
|
||||
if (listApiConfig?.[0]?.name) {
|
||||
const apiConfig = await this.configManager.LoadConfig(listApiConfig?.[0]?.name);
|
||||
const apiConfig = await this.configManager.LoadConfig(
|
||||
listApiConfig?.[0]?.name,
|
||||
)
|
||||
|
||||
await Promise.all([
|
||||
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||
@@ -485,18 +483,15 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
await this.postStateToWebview()
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await Promise.all(
|
||||
[
|
||||
await Promise.all([
|
||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||
await this.postMessageToWebview({ type: "listApiConfig", listApiConfig })
|
||||
]
|
||||
)
|
||||
}).catch(console.error);
|
||||
await this.postMessageToWebview({ type: "listApiConfig", listApiConfig }),
|
||||
])
|
||||
})
|
||||
.catch(console.error)
|
||||
|
||||
break
|
||||
case "newTask":
|
||||
@@ -593,7 +588,10 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
break
|
||||
case "refreshOpenAiModels":
|
||||
if (message?.values?.baseUrl && message?.values?.apiKey) {
|
||||
const openAiModels = await this.getOpenAiModels(message?.values?.baseUrl, message?.values?.apiKey)
|
||||
const openAiModels = await this.getOpenAiModels(
|
||||
message?.values?.baseUrl,
|
||||
message?.values?.apiKey,
|
||||
)
|
||||
this.postMessageToWebview({ type: "openAiModels", openAiModels })
|
||||
}
|
||||
break
|
||||
@@ -625,12 +623,12 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
|
||||
break
|
||||
case "allowedCommands":
|
||||
await this.context.globalState.update('allowedCommands', message.commands);
|
||||
await this.context.globalState.update("allowedCommands", message.commands)
|
||||
// Also update workspace settings
|
||||
await vscode.workspace
|
||||
.getConfiguration('roo-cline')
|
||||
.update('allowedCommands', message.commands, vscode.ConfigurationTarget.Global);
|
||||
break;
|
||||
.getConfiguration("roo-cline")
|
||||
.update("allowedCommands", message.commands, vscode.ConfigurationTarget.Global)
|
||||
break
|
||||
case "openMcpSettings": {
|
||||
const mcpSettingsFilePath = await this.mcpHub?.getMcpSettingsFilePath()
|
||||
if (mcpSettingsFilePath) {
|
||||
@@ -651,7 +649,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
await this.mcpHub?.toggleToolAlwaysAllow(
|
||||
message.serverName!,
|
||||
message.toolName!,
|
||||
message.alwaysAllow!
|
||||
message.alwaysAllow!,
|
||||
)
|
||||
} catch (error) {
|
||||
console.error(`Failed to toggle auto-approve for tool ${message.toolName}:`, error)
|
||||
@@ -660,10 +658,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
case "toggleMcpServer": {
|
||||
try {
|
||||
await this.mcpHub?.toggleServerDisabled(
|
||||
message.serverName!,
|
||||
message.disabled!
|
||||
)
|
||||
await this.mcpHub?.toggleServerDisabled(message.serverName!, message.disabled!)
|
||||
} catch (error) {
|
||||
console.error(`Failed to toggle MCP server ${message.serverName}:`, error)
|
||||
}
|
||||
@@ -739,19 +734,19 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
|
||||
// If this mode has a saved config, use it
|
||||
if (savedConfigId) {
|
||||
const config = listApiConfig?.find(c => c.id === savedConfigId)
|
||||
const config = listApiConfig?.find((c) => c.id === savedConfigId)
|
||||
if (config?.name) {
|
||||
const apiConfig = await this.configManager.LoadConfig(config.name)
|
||||
await Promise.all([
|
||||
this.updateGlobalState("currentApiConfigName", config.name),
|
||||
this.updateApiConfiguration(apiConfig)
|
||||
this.updateApiConfiguration(apiConfig),
|
||||
])
|
||||
}
|
||||
} else {
|
||||
// If no saved config for this mode, save current config as default
|
||||
const currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
||||
if (currentApiConfigName) {
|
||||
const config = listApiConfig?.find(c => c.name === currentApiConfigName)
|
||||
const config = listApiConfig?.find((c) => c.name === currentApiConfigName)
|
||||
if (config?.id) {
|
||||
await this.configManager.SetModeConfig(newMode, config.id)
|
||||
}
|
||||
@@ -761,11 +756,11 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
await this.postStateToWebview()
|
||||
break
|
||||
case "updateEnhancedPrompt":
|
||||
const existingPrompts = await this.getGlobalState("customPrompts") || {}
|
||||
const existingPrompts = (await this.getGlobalState("customPrompts")) || {}
|
||||
|
||||
const updatedPrompts = {
|
||||
...existingPrompts,
|
||||
enhance: message.text
|
||||
enhance: message.text,
|
||||
}
|
||||
|
||||
await this.updateGlobalState("customPrompts", updatedPrompts)
|
||||
@@ -775,22 +770,22 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
|
||||
const stateWithPrompts = {
|
||||
...currentState,
|
||||
customPrompts: updatedPrompts
|
||||
customPrompts: updatedPrompts,
|
||||
}
|
||||
|
||||
// Post state with prompts
|
||||
this.view?.webview.postMessage({
|
||||
type: "state",
|
||||
state: stateWithPrompts
|
||||
state: stateWithPrompts,
|
||||
})
|
||||
break
|
||||
case "updatePrompt":
|
||||
if (message.promptMode && message.customPrompt !== undefined) {
|
||||
const existingPrompts = await this.getGlobalState("customPrompts") || {}
|
||||
const existingPrompts = (await this.getGlobalState("customPrompts")) || {}
|
||||
|
||||
const updatedPrompts = {
|
||||
...existingPrompts,
|
||||
[message.promptMode]: message.customPrompt
|
||||
[message.promptMode]: message.customPrompt,
|
||||
}
|
||||
|
||||
await this.updateGlobalState("customPrompts", updatedPrompts)
|
||||
@@ -800,13 +795,13 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
|
||||
const stateWithPrompts = {
|
||||
...currentState,
|
||||
customPrompts: updatedPrompts
|
||||
customPrompts: updatedPrompts,
|
||||
}
|
||||
|
||||
// Post state with prompts
|
||||
this.view?.webview.postMessage({
|
||||
type: "state",
|
||||
state: stateWithPrompts
|
||||
state: stateWithPrompts,
|
||||
})
|
||||
}
|
||||
break
|
||||
@@ -817,11 +812,19 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
"Just this message",
|
||||
"This and all subsequent messages",
|
||||
)
|
||||
if ((answer === "Just this message" || answer === "This and all subsequent messages") &&
|
||||
this.cline && typeof message.value === 'number' && message.value) {
|
||||
const timeCutoff = message.value - 1000; // 1 second buffer before the message to delete
|
||||
const messageIndex = this.cline.clineMessages.findIndex(msg => msg.ts && msg.ts >= timeCutoff)
|
||||
const apiConversationHistoryIndex = this.cline.apiConversationHistory.findIndex(msg => msg.ts && msg.ts >= timeCutoff)
|
||||
if (
|
||||
(answer === "Just this message" || answer === "This and all subsequent messages") &&
|
||||
this.cline &&
|
||||
typeof message.value === "number" &&
|
||||
message.value
|
||||
) {
|
||||
const timeCutoff = message.value - 1000 // 1 second buffer before the message to delete
|
||||
const messageIndex = this.cline.clineMessages.findIndex(
|
||||
(msg) => msg.ts && msg.ts >= timeCutoff,
|
||||
)
|
||||
const apiConversationHistoryIndex = this.cline.apiConversationHistory.findIndex(
|
||||
(msg) => msg.ts && msg.ts >= timeCutoff,
|
||||
)
|
||||
|
||||
if (messageIndex !== -1) {
|
||||
const { historyItem } = await this.getTaskWithId(this.cline.taskId)
|
||||
@@ -830,21 +833,23 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
// Find the next user message first
|
||||
const nextUserMessage = this.cline.clineMessages
|
||||
.slice(messageIndex + 1)
|
||||
.find(msg => msg.type === "say" && msg.say === "user_feedback")
|
||||
.find((msg) => msg.type === "say" && msg.say === "user_feedback")
|
||||
|
||||
// Handle UI messages
|
||||
if (nextUserMessage) {
|
||||
// Find absolute index of next user message
|
||||
const nextUserMessageIndex = this.cline.clineMessages.findIndex(msg => msg === nextUserMessage)
|
||||
const nextUserMessageIndex = this.cline.clineMessages.findIndex(
|
||||
(msg) => msg === nextUserMessage,
|
||||
)
|
||||
// Keep messages before current message and after next user message
|
||||
await this.cline.overwriteClineMessages([
|
||||
...this.cline.clineMessages.slice(0, messageIndex),
|
||||
...this.cline.clineMessages.slice(nextUserMessageIndex)
|
||||
...this.cline.clineMessages.slice(nextUserMessageIndex),
|
||||
])
|
||||
} else {
|
||||
// If no next user message, keep only messages before current message
|
||||
await this.cline.overwriteClineMessages(
|
||||
this.cline.clineMessages.slice(0, messageIndex)
|
||||
this.cline.clineMessages.slice(0, messageIndex),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -853,21 +858,30 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
if (nextUserMessage && nextUserMessage.ts) {
|
||||
// Keep messages before current API message and after next user message
|
||||
await this.cline.overwriteApiConversationHistory([
|
||||
...this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex),
|
||||
...this.cline.apiConversationHistory.filter(msg => msg.ts && msg.ts >= nextUserMessage.ts)
|
||||
...this.cline.apiConversationHistory.slice(
|
||||
0,
|
||||
apiConversationHistoryIndex,
|
||||
),
|
||||
...this.cline.apiConversationHistory.filter(
|
||||
(msg) => msg.ts && msg.ts >= nextUserMessage.ts,
|
||||
),
|
||||
])
|
||||
} else {
|
||||
// If no next user message, keep only messages before current API message
|
||||
await this.cline.overwriteApiConversationHistory(
|
||||
this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex)
|
||||
this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if (answer === "This and all subsequent messages") {
|
||||
// Delete this message and all that follow
|
||||
await this.cline.overwriteClineMessages(this.cline.clineMessages.slice(0, messageIndex))
|
||||
await this.cline.overwriteClineMessages(
|
||||
this.cline.clineMessages.slice(0, messageIndex),
|
||||
)
|
||||
if (apiConversationHistoryIndex !== -1) {
|
||||
await this.cline.overwriteApiConversationHistory(this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex))
|
||||
await this.cline.overwriteApiConversationHistory(
|
||||
this.cline.apiConversationHistory.slice(0, apiConversationHistoryIndex),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -891,12 +905,13 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
case "enhancePrompt":
|
||||
if (message.text) {
|
||||
try {
|
||||
const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } = await this.getState()
|
||||
const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } =
|
||||
await this.getState()
|
||||
|
||||
// Try to get enhancement config first, fall back to current config
|
||||
let configToUse: ApiConfiguration = apiConfiguration
|
||||
if (enhancementApiConfigId) {
|
||||
const config = listApiConfigMeta?.find(c => c.id === enhancementApiConfigId)
|
||||
const config = listApiConfigMeta?.find((c) => c.id === enhancementApiConfigId)
|
||||
if (config?.name) {
|
||||
const loadedConfig = await this.configManager.LoadConfig(config.name)
|
||||
if (loadedConfig.apiProvider) {
|
||||
@@ -906,39 +921,47 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
|
||||
const getEnhancePrompt = (value: string | PromptComponent | undefined): string => {
|
||||
if (typeof value === 'string') {
|
||||
return value;
|
||||
if (typeof value === "string") {
|
||||
return value
|
||||
}
|
||||
return enhance.prompt; // Use the constant from modes.ts which we know is a string
|
||||
return enhance.prompt // Use the constant from modes.ts which we know is a string
|
||||
}
|
||||
const enhancedPrompt = await enhancePrompt(
|
||||
configToUse,
|
||||
message.text,
|
||||
getEnhancePrompt(customPrompts?.enhance)
|
||||
getEnhancePrompt(customPrompts?.enhance),
|
||||
)
|
||||
await this.postMessageToWebview({
|
||||
type: "enhancedPrompt",
|
||||
text: enhancedPrompt
|
||||
text: enhancedPrompt,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error("Error enhancing prompt:", error)
|
||||
vscode.window.showErrorMessage("Failed to enhance prompt")
|
||||
await this.postMessageToWebview({
|
||||
type: "enhancedPrompt"
|
||||
type: "enhancedPrompt",
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
case "getSystemPrompt":
|
||||
try {
|
||||
const { apiConfiguration, customPrompts, customInstructions, preferredLanguage, browserViewportSize, mcpEnabled } = await this.getState()
|
||||
const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || ''
|
||||
const {
|
||||
apiConfiguration,
|
||||
customPrompts,
|
||||
customInstructions,
|
||||
preferredLanguage,
|
||||
browserViewportSize,
|
||||
mcpEnabled,
|
||||
} = await this.getState()
|
||||
const cwd =
|
||||
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || ""
|
||||
|
||||
const mode = message.mode ?? defaultModeSlug
|
||||
const instructions = await addCustomInstructions(
|
||||
{ customInstructions, customPrompts, preferredLanguage },
|
||||
cwd,
|
||||
mode
|
||||
mode,
|
||||
)
|
||||
|
||||
const systemPrompt = await SYSTEM_PROMPT(
|
||||
@@ -948,14 +971,14 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
undefined,
|
||||
browserViewportSize ?? "900x600",
|
||||
mode,
|
||||
customPrompts
|
||||
customPrompts,
|
||||
)
|
||||
const fullPrompt = instructions ? `${systemPrompt}${instructions}` : systemPrompt
|
||||
|
||||
await this.postMessageToWebview({
|
||||
type: "systemPrompt",
|
||||
text: fullPrompt,
|
||||
mode: message.mode
|
||||
mode: message.mode,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error("Error getting system prompt:", error)
|
||||
@@ -969,7 +992,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
const commits = await searchCommits(message.query || "", cwd)
|
||||
await this.postMessageToWebview({
|
||||
type: "commitSearchResults",
|
||||
commits
|
||||
commits,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error("Error searching commits:", error)
|
||||
@@ -981,8 +1004,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
case "upsertApiConfiguration":
|
||||
if (message.text && message.apiConfiguration) {
|
||||
try {
|
||||
await this.configManager.SaveConfig(message.text, message.apiConfiguration);
|
||||
let listApiConfig = await this.configManager.ListConfig();
|
||||
await this.configManager.SaveConfig(message.text, message.apiConfiguration)
|
||||
let listApiConfig = await this.configManager.ListConfig()
|
||||
|
||||
await Promise.all([
|
||||
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||
@@ -1002,18 +1025,16 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
try {
|
||||
const { oldName, newName } = message.values
|
||||
|
||||
await this.configManager.SaveConfig(newName, message.apiConfiguration);
|
||||
await this.configManager.SaveConfig(newName, message.apiConfiguration)
|
||||
await this.configManager.DeleteConfig(oldName)
|
||||
|
||||
let listApiConfig = await this.configManager.ListConfig();
|
||||
const config = listApiConfig?.find(c => c.name === newName);
|
||||
let listApiConfig = await this.configManager.ListConfig()
|
||||
const config = listApiConfig?.find((c) => c.name === newName)
|
||||
|
||||
// Update listApiConfigMeta first to ensure UI has latest data
|
||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig);
|
||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||
|
||||
await Promise.all([
|
||||
this.updateGlobalState("currentApiConfigName", newName),
|
||||
])
|
||||
await Promise.all([this.updateGlobalState("currentApiConfigName", newName)])
|
||||
|
||||
await this.postStateToWebview()
|
||||
} catch (error) {
|
||||
@@ -1025,8 +1046,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
case "loadApiConfiguration":
|
||||
if (message.text) {
|
||||
try {
|
||||
const apiConfig = await this.configManager.LoadConfig(message.text);
|
||||
const listApiConfig = await this.configManager.ListConfig();
|
||||
const apiConfig = await this.configManager.LoadConfig(message.text)
|
||||
const listApiConfig = await this.configManager.ListConfig()
|
||||
|
||||
await Promise.all([
|
||||
this.updateGlobalState("listApiConfigMeta", listApiConfig),
|
||||
@@ -1054,16 +1075,16 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
|
||||
try {
|
||||
await this.configManager.DeleteConfig(message.text);
|
||||
const listApiConfig = await this.configManager.ListConfig();
|
||||
await this.configManager.DeleteConfig(message.text)
|
||||
const listApiConfig = await this.configManager.ListConfig()
|
||||
|
||||
// Update listApiConfigMeta first to ensure UI has latest data
|
||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig);
|
||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||
|
||||
// If this was the current config, switch to first available
|
||||
let currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
||||
if (message.text === currentApiConfigName && listApiConfig?.[0]?.name) {
|
||||
const apiConfig = await this.configManager.LoadConfig(listApiConfig[0].name);
|
||||
const apiConfig = await this.configManager.LoadConfig(listApiConfig[0].name)
|
||||
await Promise.all([
|
||||
this.updateGlobalState("currentApiConfigName", listApiConfig[0].name),
|
||||
this.updateApiConfiguration(apiConfig),
|
||||
@@ -1079,7 +1100,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
break
|
||||
case "getListApiConfiguration":
|
||||
try {
|
||||
let listApiConfig = await this.configManager.ListConfig();
|
||||
let listApiConfig = await this.configManager.ListConfig()
|
||||
await this.updateGlobalState("listApiConfigMeta", listApiConfig)
|
||||
this.postMessageToWebview({ type: "listApiConfig", listApiConfig })
|
||||
} catch (error) {
|
||||
@@ -1103,13 +1124,13 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
|
||||
private async updateApiConfiguration(apiConfiguration: ApiConfiguration) {
|
||||
// Update mode's default config
|
||||
const { mode } = await this.getState();
|
||||
const { mode } = await this.getState()
|
||||
if (mode) {
|
||||
const currentApiConfigName = await this.getGlobalState("currentApiConfigName");
|
||||
const listApiConfig = await this.configManager.ListConfig();
|
||||
const config = listApiConfig?.find(c => c.name === currentApiConfigName);
|
||||
const currentApiConfigName = await this.getGlobalState("currentApiConfigName")
|
||||
const listApiConfig = await this.configManager.ListConfig()
|
||||
const config = listApiConfig?.find((c) => c.name === currentApiConfigName)
|
||||
if (config?.id) {
|
||||
await this.configManager.SetModeConfig(mode, config.id);
|
||||
await this.configManager.SetModeConfig(mode, config.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1252,11 +1273,11 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
// VSCode LM API
|
||||
private async getVsCodeLmModels() {
|
||||
try {
|
||||
const models = await vscode.lm.selectChatModels({});
|
||||
return models || [];
|
||||
const models = await vscode.lm.selectChatModels({})
|
||||
return models || []
|
||||
} catch (error) {
|
||||
console.error('Error fetching VS Code LM models:', error);
|
||||
return [];
|
||||
console.error("Error fetching VS Code LM models:", error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1346,10 +1367,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
|
||||
async readGlamaModels(): Promise<Record<string, ModelInfo> | undefined> {
|
||||
const glamaModelsFilePath = path.join(
|
||||
await this.ensureCacheDirectoryExists(),
|
||||
GlobalFileNames.glamaModels,
|
||||
)
|
||||
const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels)
|
||||
const fileExists = await fileExistsAtPath(glamaModelsFilePath)
|
||||
if (fileExists) {
|
||||
const fileContents = await fs.readFile(glamaModelsFilePath, "utf8")
|
||||
@@ -1359,10 +1377,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
|
||||
async refreshGlamaModels() {
|
||||
const glamaModelsFilePath = path.join(
|
||||
await this.ensureCacheDirectoryExists(),
|
||||
GlobalFileNames.glamaModels,
|
||||
)
|
||||
const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels)
|
||||
|
||||
let models: Record<string, ModelInfo> = {}
|
||||
try {
|
||||
@@ -1397,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
*/
|
||||
if (response.data) {
|
||||
const rawModels = response.data;
|
||||
const rawModels = response.data
|
||||
const parsePrice = (price: any) => {
|
||||
if (price) {
|
||||
return parseFloat(price) * 1_000_000
|
||||
@@ -1565,7 +1580,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
uiMessagesFilePath: string
|
||||
apiConversationHistory: Anthropic.MessageParam[]
|
||||
}> {
|
||||
const history = (await this.getGlobalState("taskHistory") as HistoryItem[] | undefined) || []
|
||||
const history = ((await this.getGlobalState("taskHistory")) as HistoryItem[] | undefined) || []
|
||||
const historyItem = history.find((item) => item.id === id)
|
||||
if (historyItem) {
|
||||
const taskDirPath = path.join(this.context.globalStorageUri.fsPath, "tasks", id)
|
||||
@@ -1630,7 +1645,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
|
||||
async deleteTaskFromState(id: string) {
|
||||
// Remove the task from history
|
||||
const taskHistory = (await this.getGlobalState("taskHistory") as HistoryItem[]) || []
|
||||
const taskHistory = ((await this.getGlobalState("taskHistory")) as HistoryItem[]) || []
|
||||
const updatedTaskHistory = taskHistory.filter((task) => task.id !== id)
|
||||
await this.updateGlobalState("taskHistory", updatedTaskHistory)
|
||||
|
||||
@@ -1675,9 +1690,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
autoApprovalEnabled,
|
||||
} = await this.getState()
|
||||
|
||||
const allowedCommands = vscode.workspace
|
||||
.getConfiguration('roo-cline')
|
||||
.get<string[]>('allowedCommands') || []
|
||||
const allowedCommands = vscode.workspace.getConfiguration("roo-cline").get<string[]>("allowedCommands") || []
|
||||
|
||||
return {
|
||||
version: this.context.extension?.packageJSON?.version ?? "",
|
||||
@@ -1700,7 +1713,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
soundVolume: soundVolume ?? 0.5,
|
||||
browserViewportSize: browserViewportSize ?? "900x600",
|
||||
screenshotQuality: screenshotQuality ?? 75,
|
||||
preferredLanguage: preferredLanguage ?? 'English',
|
||||
preferredLanguage: preferredLanguage ?? "English",
|
||||
writeDelayMs: writeDelayMs ?? 1000,
|
||||
terminalOutputLineLimit: terminalOutputLineLimit ?? 500,
|
||||
fuzzyMatchThreshold: fuzzyMatchThreshold ?? 1.0,
|
||||
@@ -1962,39 +1975,41 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
writeDelayMs: writeDelayMs ?? 1000,
|
||||
terminalOutputLineLimit: terminalOutputLineLimit ?? 500,
|
||||
mode: mode ?? defaultModeSlug,
|
||||
preferredLanguage: preferredLanguage ?? (() => {
|
||||
preferredLanguage:
|
||||
preferredLanguage ??
|
||||
(() => {
|
||||
// Get VSCode's locale setting
|
||||
const vscodeLang = vscode.env.language;
|
||||
const vscodeLang = vscode.env.language
|
||||
// Map VSCode locale to our supported languages
|
||||
const langMap: { [key: string]: string } = {
|
||||
'en': 'English',
|
||||
'ar': 'Arabic',
|
||||
'pt-br': 'Brazilian Portuguese',
|
||||
'cs': 'Czech',
|
||||
'fr': 'French',
|
||||
'de': 'German',
|
||||
'hi': 'Hindi',
|
||||
'hu': 'Hungarian',
|
||||
'it': 'Italian',
|
||||
'ja': 'Japanese',
|
||||
'ko': 'Korean',
|
||||
'pl': 'Polish',
|
||||
'pt': 'Portuguese',
|
||||
'ru': 'Russian',
|
||||
'zh-cn': 'Simplified Chinese',
|
||||
'es': 'Spanish',
|
||||
'zh-tw': 'Traditional Chinese',
|
||||
'tr': 'Turkish'
|
||||
};
|
||||
en: "English",
|
||||
ar: "Arabic",
|
||||
"pt-br": "Brazilian Portuguese",
|
||||
cs: "Czech",
|
||||
fr: "French",
|
||||
de: "German",
|
||||
hi: "Hindi",
|
||||
hu: "Hungarian",
|
||||
it: "Italian",
|
||||
ja: "Japanese",
|
||||
ko: "Korean",
|
||||
pl: "Polish",
|
||||
pt: "Portuguese",
|
||||
ru: "Russian",
|
||||
"zh-cn": "Simplified Chinese",
|
||||
es: "Spanish",
|
||||
"zh-tw": "Traditional Chinese",
|
||||
tr: "Turkish",
|
||||
}
|
||||
// Return mapped language or default to English
|
||||
return langMap[vscodeLang.split('-')[0]] ?? 'English';
|
||||
return langMap[vscodeLang.split("-")[0]] ?? "English"
|
||||
})(),
|
||||
mcpEnabled: mcpEnabled ?? true,
|
||||
alwaysApproveResubmit: alwaysApproveResubmit ?? false,
|
||||
requestDelaySeconds: requestDelaySeconds ?? 5,
|
||||
currentApiConfigName: currentApiConfigName ?? "default",
|
||||
listApiConfigMeta: listApiConfigMeta ?? [],
|
||||
modeApiConfigs: modeApiConfigs ?? {} as Record<Mode, string>,
|
||||
modeApiConfigs: modeApiConfigs ?? ({} as Record<Mode, string>),
|
||||
customPrompts: customPrompts ?? {},
|
||||
enhancementApiConfigId,
|
||||
experimentalDiffStrategy: experimentalDiffStrategy ?? false,
|
||||
@@ -2003,7 +2018,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
}
|
||||
|
||||
async updateTaskHistory(item: HistoryItem): Promise<HistoryItem[]> {
|
||||
const history = (await this.getGlobalState("taskHistory") as HistoryItem[] | undefined) || []
|
||||
const history = ((await this.getGlobalState("taskHistory")) as HistoryItem[] | undefined) || []
|
||||
const existingItemIndex = history.findIndex((h) => h.id === item.id)
|
||||
|
||||
if (existingItemIndex !== -1) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,13 +27,11 @@ export function activate(context: vscode.ExtensionContext) {
|
||||
outputChannel.appendLine("Cline extension activated")
|
||||
|
||||
// Get default commands from configuration
|
||||
const defaultCommands = vscode.workspace
|
||||
.getConfiguration('roo-cline')
|
||||
.get<string[]>('allowedCommands') || [];
|
||||
const defaultCommands = vscode.workspace.getConfiguration("roo-cline").get<string[]>("allowedCommands") || []
|
||||
|
||||
// Initialize global state if not already set
|
||||
if (!context.globalState.get('allowedCommands')) {
|
||||
context.globalState.update('allowedCommands', defaultCommands);
|
||||
if (!context.globalState.get("allowedCommands")) {
|
||||
context.globalState.update("allowedCommands", defaultCommands)
|
||||
}
|
||||
|
||||
const sidebarProvider = new ClineProvider(context, outputChannel)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { DiffViewProvider } from '../DiffViewProvider';
|
||||
import * as vscode from 'vscode';
|
||||
import { DiffViewProvider } from "../DiffViewProvider"
|
||||
import * as vscode from "vscode"
|
||||
|
||||
// Mock vscode
|
||||
jest.mock('vscode', () => ({
|
||||
jest.mock("vscode", () => ({
|
||||
workspace: {
|
||||
applyEdit: jest.fn(),
|
||||
},
|
||||
@@ -19,34 +19,34 @@ jest.mock('vscode', () => ({
|
||||
TextEditorRevealType: {
|
||||
InCenter: 2,
|
||||
},
|
||||
}));
|
||||
}))
|
||||
|
||||
// Mock DecorationController
|
||||
jest.mock('../DecorationController', () => ({
|
||||
jest.mock("../DecorationController", () => ({
|
||||
DecorationController: jest.fn().mockImplementation(() => ({
|
||||
setActiveLine: jest.fn(),
|
||||
updateOverlayAfterLine: jest.fn(),
|
||||
clear: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
}))
|
||||
|
||||
describe('DiffViewProvider', () => {
|
||||
let diffViewProvider: DiffViewProvider;
|
||||
const mockCwd = '/mock/cwd';
|
||||
let mockWorkspaceEdit: { replace: jest.Mock; delete: jest.Mock };
|
||||
describe("DiffViewProvider", () => {
|
||||
let diffViewProvider: DiffViewProvider
|
||||
const mockCwd = "/mock/cwd"
|
||||
let mockWorkspaceEdit: { replace: jest.Mock; delete: jest.Mock }
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
jest.clearAllMocks()
|
||||
mockWorkspaceEdit = {
|
||||
replace: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
};
|
||||
(vscode.WorkspaceEdit as jest.Mock).mockImplementation(() => mockWorkspaceEdit);
|
||||
}
|
||||
;(vscode.WorkspaceEdit as jest.Mock).mockImplementation(() => mockWorkspaceEdit)
|
||||
|
||||
diffViewProvider = new DiffViewProvider(mockCwd);
|
||||
diffViewProvider = new DiffViewProvider(mockCwd)
|
||||
// Mock the necessary properties and methods
|
||||
(diffViewProvider as any).relPath = 'test.txt';
|
||||
(diffViewProvider as any).activeDiffEditor = {
|
||||
;(diffViewProvider as any).relPath = "test.txt"
|
||||
;(diffViewProvider as any).activeDiffEditor = {
|
||||
document: {
|
||||
uri: { fsPath: `${mockCwd}/test.txt` },
|
||||
getText: jest.fn(),
|
||||
@@ -58,43 +58,39 @@ describe('DiffViewProvider', () => {
|
||||
},
|
||||
edit: jest.fn().mockResolvedValue(true),
|
||||
revealRange: jest.fn(),
|
||||
};
|
||||
(diffViewProvider as any).activeLineController = { setActiveLine: jest.fn(), clear: jest.fn() };
|
||||
(diffViewProvider as any).fadedOverlayController = { updateOverlayAfterLine: jest.fn(), clear: jest.fn() };
|
||||
});
|
||||
}
|
||||
;(diffViewProvider as any).activeLineController = { setActiveLine: jest.fn(), clear: jest.fn() }
|
||||
;(diffViewProvider as any).fadedOverlayController = { updateOverlayAfterLine: jest.fn(), clear: jest.fn() }
|
||||
})
|
||||
|
||||
describe('update method', () => {
|
||||
it('should preserve empty last line when original content has one', async () => {
|
||||
(diffViewProvider as any).originalContent = 'Original content\n';
|
||||
await diffViewProvider.update('New content', true);
|
||||
describe("update method", () => {
|
||||
it("should preserve empty last line when original content has one", async () => {
|
||||
;(diffViewProvider as any).originalContent = "Original content\n"
|
||||
await diffViewProvider.update("New content", true)
|
||||
|
||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
'New content\n'
|
||||
);
|
||||
});
|
||||
"New content\n",
|
||||
)
|
||||
})
|
||||
|
||||
it('should not add extra newline when accumulated content already ends with one', async () => {
|
||||
(diffViewProvider as any).originalContent = 'Original content\n';
|
||||
await diffViewProvider.update('New content\n', true);
|
||||
it("should not add extra newline when accumulated content already ends with one", async () => {
|
||||
;(diffViewProvider as any).originalContent = "Original content\n"
|
||||
await diffViewProvider.update("New content\n", true)
|
||||
|
||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
'New content\n'
|
||||
);
|
||||
});
|
||||
"New content\n",
|
||||
)
|
||||
})
|
||||
|
||||
it('should not add newline when original content does not end with one', async () => {
|
||||
(diffViewProvider as any).originalContent = 'Original content';
|
||||
await diffViewProvider.update('New content', true);
|
||||
it("should not add newline when original content does not end with one", async () => {
|
||||
;(diffViewProvider as any).originalContent = "Original content"
|
||||
await diffViewProvider.update("New content", true)
|
||||
|
||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
'New content'
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
expect(mockWorkspaceEdit.replace).toHaveBeenCalledWith(expect.anything(), expect.anything(), "New content")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { detectCodeOmission } from '../detect-omission'
|
||||
import { detectCodeOmission } from "../detect-omission"
|
||||
|
||||
describe('detectCodeOmission', () => {
|
||||
describe("detectCodeOmission", () => {
|
||||
const originalContent = `function example() {
|
||||
// Some code
|
||||
const x = 1;
|
||||
@@ -10,124 +10,132 @@ describe('detectCodeOmission', () => {
|
||||
|
||||
const generateLongContent = (commentLine: string, length: number = 90) => {
|
||||
return `${commentLine}
|
||||
${Array.from({ length }, (_, i) => `const x${i} = ${i};`).join('\n')}
|
||||
${Array.from({ length }, (_, i) => `const x${i} = ${i};`).join("\n")}
|
||||
const y = 2;`
|
||||
}
|
||||
|
||||
it('should skip comment checks for files under 100 lines', () => {
|
||||
it("should skip comment checks for files under 100 lines", () => {
|
||||
const newContent = `// Lines 1-50 remain unchanged
|
||||
const z = 3;`
|
||||
const predictedLineCount = 50
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should not detect regular comments without omission keywords', () => {
|
||||
const newContent = generateLongContent('// Adding new functionality')
|
||||
it("should not detect regular comments without omission keywords", () => {
|
||||
const newContent = generateLongContent("// Adding new functionality")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should not detect when comment is part of original content', () => {
|
||||
it("should not detect when comment is part of original content", () => {
|
||||
const originalWithComment = `// Content remains unchanged
|
||||
${originalContent}`
|
||||
const newContent = generateLongContent('// Content remains unchanged')
|
||||
const newContent = generateLongContent("// Content remains unchanged")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalWithComment, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should not detect code that happens to contain omission keywords', () => {
|
||||
it("should not detect code that happens to contain omission keywords", () => {
|
||||
const newContent = generateLongContent(`const remains = 'some value';
|
||||
const unchanged = true;`)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect suspicious single-line comment when content is more than 20% shorter', () => {
|
||||
const newContent = generateLongContent('// Previous content remains here\nconst x = 1;')
|
||||
it("should detect suspicious single-line comment when content is more than 20% shorter", () => {
|
||||
const newContent = generateLongContent("// Previous content remains here\nconst x = 1;")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||
})
|
||||
|
||||
it('should not flag suspicious single-line comment when content is less than 20% shorter', () => {
|
||||
const newContent = generateLongContent('// Previous content remains here', 130)
|
||||
it("should not flag suspicious single-line comment when content is less than 20% shorter", () => {
|
||||
const newContent = generateLongContent("// Previous content remains here", 130)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect suspicious Python-style comment when content is more than 20% shorter', () => {
|
||||
const newContent = generateLongContent('# Previous content remains here\nconst x = 1;')
|
||||
it("should detect suspicious Python-style comment when content is more than 20% shorter", () => {
|
||||
const newContent = generateLongContent("# Previous content remains here\nconst x = 1;")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||
})
|
||||
|
||||
it('should not flag suspicious Python-style comment when content is less than 20% shorter', () => {
|
||||
const newContent = generateLongContent('# Previous content remains here', 130)
|
||||
it("should not flag suspicious Python-style comment when content is less than 20% shorter", () => {
|
||||
const newContent = generateLongContent("# Previous content remains here", 130)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect suspicious multi-line comment when content is more than 20% shorter', () => {
|
||||
const newContent = generateLongContent('/* Previous content remains the same */\nconst x = 1;')
|
||||
it("should detect suspicious multi-line comment when content is more than 20% shorter", () => {
|
||||
const newContent = generateLongContent("/* Previous content remains the same */\nconst x = 1;")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||
})
|
||||
|
||||
it('should not flag suspicious multi-line comment when content is less than 20% shorter', () => {
|
||||
const newContent = generateLongContent('/* Previous content remains the same */', 130)
|
||||
it("should not flag suspicious multi-line comment when content is less than 20% shorter", () => {
|
||||
const newContent = generateLongContent("/* Previous content remains the same */", 130)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect suspicious JSX comment when content is more than 20% shorter', () => {
|
||||
const newContent = generateLongContent('{/* Rest of the code remains the same */}\nconst x = 1;')
|
||||
it("should detect suspicious JSX comment when content is more than 20% shorter", () => {
|
||||
const newContent = generateLongContent("{/* Rest of the code remains the same */}\nconst x = 1;")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||
})
|
||||
|
||||
it('should not flag suspicious JSX comment when content is less than 20% shorter', () => {
|
||||
const newContent = generateLongContent('{/* Rest of the code remains the same */}', 130)
|
||||
it("should not flag suspicious JSX comment when content is less than 20% shorter", () => {
|
||||
const newContent = generateLongContent("{/* Rest of the code remains the same */}", 130)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect suspicious HTML comment when content is more than 20% shorter', () => {
|
||||
const newContent = generateLongContent('<!-- Existing content unchanged -->\nconst x = 1;')
|
||||
it("should detect suspicious HTML comment when content is more than 20% shorter", () => {
|
||||
const newContent = generateLongContent("<!-- Existing content unchanged -->\nconst x = 1;")
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||
})
|
||||
|
||||
it('should not flag suspicious HTML comment when content is less than 20% shorter', () => {
|
||||
const newContent = generateLongContent('<!-- Existing content unchanged -->', 130)
|
||||
it("should not flag suspicious HTML comment when content is less than 20% shorter", () => {
|
||||
const newContent = generateLongContent("<!-- Existing content unchanged -->", 130)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect suspicious square bracket notation when content is more than 20% shorter', () => {
|
||||
const newContent = generateLongContent('[Previous content from line 1-305 remains exactly the same]\nconst x = 1;')
|
||||
it("should detect suspicious square bracket notation when content is more than 20% shorter", () => {
|
||||
const newContent = generateLongContent(
|
||||
"[Previous content from line 1-305 remains exactly the same]\nconst x = 1;",
|
||||
)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(true)
|
||||
})
|
||||
|
||||
it('should not flag suspicious square bracket notation when content is less than 20% shorter', () => {
|
||||
const newContent = generateLongContent('[Previous content from line 1-305 remains exactly the same]', 130)
|
||||
it("should not flag suspicious square bracket notation when content is less than 20% shorter", () => {
|
||||
const newContent = generateLongContent("[Previous content from line 1-305 remains exactly the same]", 130)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should not flag content very close to predicted length', () => {
|
||||
const newContent = generateLongContent(`const x = 1;
|
||||
it("should not flag content very close to predicted length", () => {
|
||||
const newContent = generateLongContent(
|
||||
`const x = 1;
|
||||
const y = 2;
|
||||
// This is a legitimate comment that remains here`, 130)
|
||||
// This is a legitimate comment that remains here`,
|
||||
130,
|
||||
)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
it('should not flag when content is longer than predicted', () => {
|
||||
const newContent = generateLongContent(`const x = 1;
|
||||
it("should not flag when content is longer than predicted", () => {
|
||||
const newContent = generateLongContent(
|
||||
`const x = 1;
|
||||
const y = 2;
|
||||
// Previous content remains here but we added more
|
||||
const z = 3;
|
||||
const w = 4;`, 160)
|
||||
const w = 4;`,
|
||||
160,
|
||||
)
|
||||
const predictedLineCount = 150
|
||||
expect(detectCodeOmission(originalContent, newContent, predictedLineCount)).toBe(false)
|
||||
})
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
export function detectCodeOmission(
|
||||
originalFileContent: string,
|
||||
newFileContent: string,
|
||||
predictedLineCount: number
|
||||
predictedLineCount: number,
|
||||
): boolean {
|
||||
// Skip all checks if predictedLineCount is less than 100
|
||||
if (!predictedLineCount || predictedLineCount < 100) {
|
||||
@@ -20,7 +20,17 @@ export function detectCodeOmission(
|
||||
|
||||
const originalLines = originalFileContent.split("\n")
|
||||
const newLines = newFileContent.split("\n")
|
||||
const omissionKeywords = ["remain", "remains", "unchanged", "rest", "previous", "existing", "content", "same", "..."]
|
||||
const omissionKeywords = [
|
||||
"remain",
|
||||
"remains",
|
||||
"unchanged",
|
||||
"rest",
|
||||
"previous",
|
||||
"existing",
|
||||
"content",
|
||||
"same",
|
||||
"...",
|
||||
]
|
||||
|
||||
const commentPatterns = [
|
||||
/^\s*\/\//, // Single-line comment for most languages
|
||||
@@ -39,7 +49,7 @@ export function detectCodeOmission(
|
||||
if (omissionKeywords.some((keyword) => words.includes(keyword))) {
|
||||
if (!originalLines.includes(line)) {
|
||||
// For files with 100+ lines, only flag if content is more than 20% shorter
|
||||
if (lengthRatio <= 0.80) {
|
||||
if (lengthRatio <= 0.8) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,122 +1,122 @@
|
||||
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers, truncateOutput } from '../extract-text';
|
||||
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers, truncateOutput } from "../extract-text"
|
||||
|
||||
describe('addLineNumbers', () => {
|
||||
it('should add line numbers starting from 1 by default', () => {
|
||||
const input = 'line 1\nline 2\nline 3';
|
||||
const expected = '1 | line 1\n2 | line 2\n3 | line 3';
|
||||
expect(addLineNumbers(input)).toBe(expected);
|
||||
});
|
||||
describe("addLineNumbers", () => {
|
||||
it("should add line numbers starting from 1 by default", () => {
|
||||
const input = "line 1\nline 2\nline 3"
|
||||
const expected = "1 | line 1\n2 | line 2\n3 | line 3"
|
||||
expect(addLineNumbers(input)).toBe(expected)
|
||||
})
|
||||
|
||||
it('should add line numbers starting from specified line number', () => {
|
||||
const input = 'line 1\nline 2\nline 3';
|
||||
const expected = '10 | line 1\n11 | line 2\n12 | line 3';
|
||||
expect(addLineNumbers(input, 10)).toBe(expected);
|
||||
});
|
||||
it("should add line numbers starting from specified line number", () => {
|
||||
const input = "line 1\nline 2\nline 3"
|
||||
const expected = "10 | line 1\n11 | line 2\n12 | line 3"
|
||||
expect(addLineNumbers(input, 10)).toBe(expected)
|
||||
})
|
||||
|
||||
it('should handle empty content', () => {
|
||||
expect(addLineNumbers('')).toBe('1 | ');
|
||||
expect(addLineNumbers('', 5)).toBe('5 | ');
|
||||
});
|
||||
it("should handle empty content", () => {
|
||||
expect(addLineNumbers("")).toBe("1 | ")
|
||||
expect(addLineNumbers("", 5)).toBe("5 | ")
|
||||
})
|
||||
|
||||
it('should handle single line content', () => {
|
||||
expect(addLineNumbers('single line')).toBe('1 | single line');
|
||||
expect(addLineNumbers('single line', 42)).toBe('42 | single line');
|
||||
});
|
||||
it("should handle single line content", () => {
|
||||
expect(addLineNumbers("single line")).toBe("1 | single line")
|
||||
expect(addLineNumbers("single line", 42)).toBe("42 | single line")
|
||||
})
|
||||
|
||||
it('should pad line numbers based on the highest line number', () => {
|
||||
const input = 'line 1\nline 2';
|
||||
it("should pad line numbers based on the highest line number", () => {
|
||||
const input = "line 1\nline 2"
|
||||
// When starting from 99, highest line will be 100, so needs 3 spaces padding
|
||||
const expected = ' 99 | line 1\n100 | line 2';
|
||||
expect(addLineNumbers(input, 99)).toBe(expected);
|
||||
});
|
||||
});
|
||||
const expected = " 99 | line 1\n100 | line 2"
|
||||
expect(addLineNumbers(input, 99)).toBe(expected)
|
||||
})
|
||||
})
|
||||
|
||||
describe('everyLineHasLineNumbers', () => {
|
||||
it('should return true for content with line numbers', () => {
|
||||
const input = '1 | line one\n2 | line two\n3 | line three';
|
||||
expect(everyLineHasLineNumbers(input)).toBe(true);
|
||||
});
|
||||
describe("everyLineHasLineNumbers", () => {
|
||||
it("should return true for content with line numbers", () => {
|
||||
const input = "1 | line one\n2 | line two\n3 | line three"
|
||||
expect(everyLineHasLineNumbers(input)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for content with padded line numbers', () => {
|
||||
const input = ' 1 | line one\n 2 | line two\n 3 | line three';
|
||||
expect(everyLineHasLineNumbers(input)).toBe(true);
|
||||
});
|
||||
it("should return true for content with padded line numbers", () => {
|
||||
const input = " 1 | line one\n 2 | line two\n 3 | line three"
|
||||
expect(everyLineHasLineNumbers(input)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for content without line numbers', () => {
|
||||
const input = 'line one\nline two\nline three';
|
||||
expect(everyLineHasLineNumbers(input)).toBe(false);
|
||||
});
|
||||
it("should return false for content without line numbers", () => {
|
||||
const input = "line one\nline two\nline three"
|
||||
expect(everyLineHasLineNumbers(input)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for mixed content', () => {
|
||||
const input = '1 | line one\nline two\n3 | line three';
|
||||
expect(everyLineHasLineNumbers(input)).toBe(false);
|
||||
});
|
||||
it("should return false for mixed content", () => {
|
||||
const input = "1 | line one\nline two\n3 | line three"
|
||||
expect(everyLineHasLineNumbers(input)).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle empty content', () => {
|
||||
expect(everyLineHasLineNumbers('')).toBe(false);
|
||||
});
|
||||
it("should handle empty content", () => {
|
||||
expect(everyLineHasLineNumbers("")).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for content with pipe but no line numbers', () => {
|
||||
const input = 'a | b\nc | d';
|
||||
expect(everyLineHasLineNumbers(input)).toBe(false);
|
||||
});
|
||||
});
|
||||
it("should return false for content with pipe but no line numbers", () => {
|
||||
const input = "a | b\nc | d"
|
||||
expect(everyLineHasLineNumbers(input)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('stripLineNumbers', () => {
|
||||
it('should strip line numbers from content', () => {
|
||||
const input = '1 | line one\n2 | line two\n3 | line three';
|
||||
const expected = 'line one\nline two\nline three';
|
||||
expect(stripLineNumbers(input)).toBe(expected);
|
||||
});
|
||||
describe("stripLineNumbers", () => {
|
||||
it("should strip line numbers from content", () => {
|
||||
const input = "1 | line one\n2 | line two\n3 | line three"
|
||||
const expected = "line one\nline two\nline three"
|
||||
expect(stripLineNumbers(input)).toBe(expected)
|
||||
})
|
||||
|
||||
it('should strip padded line numbers', () => {
|
||||
const input = ' 1 | line one\n 2 | line two\n 3 | line three';
|
||||
const expected = 'line one\nline two\nline three';
|
||||
expect(stripLineNumbers(input)).toBe(expected);
|
||||
});
|
||||
it("should strip padded line numbers", () => {
|
||||
const input = " 1 | line one\n 2 | line two\n 3 | line three"
|
||||
const expected = "line one\nline two\nline three"
|
||||
expect(stripLineNumbers(input)).toBe(expected)
|
||||
})
|
||||
|
||||
it('should handle content without line numbers', () => {
|
||||
const input = 'line one\nline two\nline three';
|
||||
expect(stripLineNumbers(input)).toBe(input);
|
||||
});
|
||||
it("should handle content without line numbers", () => {
|
||||
const input = "line one\nline two\nline three"
|
||||
expect(stripLineNumbers(input)).toBe(input)
|
||||
})
|
||||
|
||||
it('should handle empty content', () => {
|
||||
expect(stripLineNumbers('')).toBe('');
|
||||
});
|
||||
it("should handle empty content", () => {
|
||||
expect(stripLineNumbers("")).toBe("")
|
||||
})
|
||||
|
||||
it('should preserve content with pipe but no line numbers', () => {
|
||||
const input = 'a | b\nc | d';
|
||||
expect(stripLineNumbers(input)).toBe(input);
|
||||
});
|
||||
it("should preserve content with pipe but no line numbers", () => {
|
||||
const input = "a | b\nc | d"
|
||||
expect(stripLineNumbers(input)).toBe(input)
|
||||
})
|
||||
|
||||
it('should handle windows-style line endings', () => {
|
||||
const input = '1 | line one\r\n2 | line two\r\n3 | line three';
|
||||
const expected = 'line one\r\nline two\r\nline three';
|
||||
expect(stripLineNumbers(input)).toBe(expected);
|
||||
});
|
||||
it("should handle windows-style line endings", () => {
|
||||
const input = "1 | line one\r\n2 | line two\r\n3 | line three"
|
||||
const expected = "line one\r\nline two\r\nline three"
|
||||
expect(stripLineNumbers(input)).toBe(expected)
|
||||
})
|
||||
|
||||
it('should handle content with varying line number widths', () => {
|
||||
const input = ' 1 | line one\n 10 | line two\n100 | line three';
|
||||
const expected = 'line one\nline two\nline three';
|
||||
expect(stripLineNumbers(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
it("should handle content with varying line number widths", () => {
|
||||
const input = " 1 | line one\n 10 | line two\n100 | line three"
|
||||
const expected = "line one\nline two\nline three"
|
||||
expect(stripLineNumbers(input)).toBe(expected)
|
||||
})
|
||||
})
|
||||
|
||||
describe('truncateOutput', () => {
|
||||
it('returns original content when no line limit provided', () => {
|
||||
const content = 'line1\nline2\nline3'
|
||||
describe("truncateOutput", () => {
|
||||
it("returns original content when no line limit provided", () => {
|
||||
const content = "line1\nline2\nline3"
|
||||
expect(truncateOutput(content)).toBe(content)
|
||||
})
|
||||
|
||||
it('returns original content when lines are under limit', () => {
|
||||
const content = 'line1\nline2\nline3'
|
||||
it("returns original content when lines are under limit", () => {
|
||||
const content = "line1\nline2\nline3"
|
||||
expect(truncateOutput(content, 5)).toBe(content)
|
||||
})
|
||||
|
||||
it('truncates content with 20/80 split when over limit', () => {
|
||||
it("truncates content with 20/80 split when over limit", () => {
|
||||
// Create 25 lines of content
|
||||
const lines = Array.from({ length: 25 }, (_, i) => `line${i + 1}`)
|
||||
const content = lines.join('\n')
|
||||
const content = lines.join("\n")
|
||||
|
||||
// Set limit to 10 lines
|
||||
const result = truncateOutput(content, 10)
|
||||
@@ -126,51 +126,42 @@ describe('truncateOutput', () => {
|
||||
// - Last 8 lines (80% of 10)
|
||||
// - Omission indicator in between
|
||||
const expectedLines = [
|
||||
'line1',
|
||||
'line2',
|
||||
'',
|
||||
'[...15 lines omitted...]',
|
||||
'',
|
||||
'line18',
|
||||
'line19',
|
||||
'line20',
|
||||
'line21',
|
||||
'line22',
|
||||
'line23',
|
||||
'line24',
|
||||
'line25'
|
||||
"line1",
|
||||
"line2",
|
||||
"",
|
||||
"[...15 lines omitted...]",
|
||||
"",
|
||||
"line18",
|
||||
"line19",
|
||||
"line20",
|
||||
"line21",
|
||||
"line22",
|
||||
"line23",
|
||||
"line24",
|
||||
"line25",
|
||||
]
|
||||
expect(result).toBe(expectedLines.join('\n'))
|
||||
expect(result).toBe(expectedLines.join("\n"))
|
||||
})
|
||||
|
||||
it('handles empty content', () => {
|
||||
expect(truncateOutput('', 10)).toBe('')
|
||||
it("handles empty content", () => {
|
||||
expect(truncateOutput("", 10)).toBe("")
|
||||
})
|
||||
|
||||
it('handles single line content', () => {
|
||||
expect(truncateOutput('single line', 10)).toBe('single line')
|
||||
it("handles single line content", () => {
|
||||
expect(truncateOutput("single line", 10)).toBe("single line")
|
||||
})
|
||||
|
||||
it('handles windows-style line endings', () => {
|
||||
it("handles windows-style line endings", () => {
|
||||
// Create content with windows line endings
|
||||
const lines = Array.from({ length: 15 }, (_, i) => `line${i + 1}`)
|
||||
const content = lines.join('\r\n')
|
||||
const content = lines.join("\r\n")
|
||||
|
||||
const result = truncateOutput(content, 5)
|
||||
|
||||
// Should keep first line (20% of 5 = 1) and last 4 lines (80% of 5 = 4)
|
||||
// Split result by either \r\n or \n to normalize line endings
|
||||
const resultLines = result.split(/\r?\n/)
|
||||
const expectedLines = [
|
||||
'line1',
|
||||
'',
|
||||
'[...10 lines omitted...]',
|
||||
'',
|
||||
'line12',
|
||||
'line13',
|
||||
'line14',
|
||||
'line15'
|
||||
]
|
||||
const expectedLines = ["line1", "", "[...10 lines omitted...]", "", "line12", "line13", "line14", "line15"]
|
||||
expect(resultLines).toEqual(expectedLines)
|
||||
})
|
||||
})
|
||||
@@ -55,19 +55,20 @@ async function extractTextFromIPYNB(filePath: string): Promise<string> {
|
||||
}
|
||||
|
||||
export function addLineNumbers(content: string, startLine: number = 1): string {
|
||||
const lines = content.split('\n')
|
||||
const lines = content.split("\n")
|
||||
const maxLineNumberWidth = String(startLine + lines.length - 1).length
|
||||
return lines
|
||||
.map((line, index) => {
|
||||
const lineNumber = String(startLine + index).padStart(maxLineNumberWidth, ' ')
|
||||
const lineNumber = String(startLine + index).padStart(maxLineNumberWidth, " ")
|
||||
return `${lineNumber} | ${line}`
|
||||
}).join('\n')
|
||||
})
|
||||
.join("\n")
|
||||
}
|
||||
// Checks if every line in the content has line numbers prefixed (e.g., "1 | content" or "123 | content")
|
||||
// Line numbers must be followed by a single pipe character (not double pipes)
|
||||
export function everyLineHasLineNumbers(content: string): boolean {
|
||||
const lines = content.split(/\r?\n/)
|
||||
return lines.length > 0 && lines.every(line => /^\s*\d+\s+\|(?!\|)/.test(line))
|
||||
return lines.length > 0 && lines.every((line) => /^\s*\d+\s+\|(?!\|)/.test(line))
|
||||
}
|
||||
|
||||
// Strips line numbers from content while preserving the actual content
|
||||
@@ -78,14 +79,14 @@ export function stripLineNumbers(content: string): string {
|
||||
const lines = content.split(/\r?\n/)
|
||||
|
||||
// Process each line
|
||||
const processedLines = lines.map(line => {
|
||||
const processedLines = lines.map((line) => {
|
||||
// Match line number pattern and capture everything after the pipe
|
||||
const match = line.match(/^\s*\d+\s+\|(?!\|)\s?(.*)$/)
|
||||
return match ? match[1] : line
|
||||
})
|
||||
|
||||
// Join back with original line endings
|
||||
const lineEnding = content.includes('\r\n') ? '\r\n' : '\n'
|
||||
const lineEnding = content.includes("\r\n") ? "\r\n" : "\n"
|
||||
return processedLines.join(lineEnding)
|
||||
}
|
||||
|
||||
@@ -109,7 +110,7 @@ export function truncateOutput(content: string, lineLimit?: number): string {
|
||||
return content
|
||||
}
|
||||
|
||||
const lines = content.split('\n')
|
||||
const lines = content.split("\n")
|
||||
if (lines.length <= lineLimit) {
|
||||
return content
|
||||
}
|
||||
@@ -119,6 +120,6 @@ export function truncateOutput(content: string, lineLimit?: number): string {
|
||||
return [
|
||||
...lines.slice(0, beforeLimit),
|
||||
`\n[...${lines.length - lineLimit} lines omitted...]\n`,
|
||||
...lines.slice(-afterLimit)
|
||||
].join('\n')
|
||||
...lines.slice(-afterLimit),
|
||||
].join("\n")
|
||||
}
|
||||
@@ -21,8 +21,8 @@ export async function openImage(dataUri: string) {
|
||||
}
|
||||
|
||||
interface OpenFileOptions {
|
||||
create?: boolean;
|
||||
content?: string;
|
||||
create?: boolean
|
||||
content?: string
|
||||
}
|
||||
|
||||
export async function openFile(filePath: string, options: OpenFileOptions = {}) {
|
||||
@@ -30,13 +30,11 @@ export async function openFile(filePath: string, options: OpenFileOptions = {})
|
||||
// Get workspace root
|
||||
const workspaceRoot = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath
|
||||
if (!workspaceRoot) {
|
||||
throw new Error('No workspace root found')
|
||||
throw new Error("No workspace root found")
|
||||
}
|
||||
|
||||
// If path starts with ./, resolve it relative to workspace root
|
||||
const fullPath = filePath.startsWith('./') ?
|
||||
path.join(workspaceRoot, filePath.slice(2)) :
|
||||
filePath
|
||||
const fullPath = filePath.startsWith("./") ? path.join(workspaceRoot, filePath.slice(2)) : filePath
|
||||
|
||||
const uri = vscode.Uri.file(fullPath)
|
||||
|
||||
@@ -46,12 +44,12 @@ export async function openFile(filePath: string, options: OpenFileOptions = {})
|
||||
} catch {
|
||||
// File doesn't exist
|
||||
if (!options.create) {
|
||||
throw new Error('File does not exist')
|
||||
throw new Error("File does not exist")
|
||||
}
|
||||
|
||||
// Create with provided content or empty string
|
||||
const content = options.content || ''
|
||||
await vscode.workspace.fs.writeFile(uri, Buffer.from(content, 'utf8'))
|
||||
const content = options.content || ""
|
||||
await vscode.workspace.fs.writeFile(uri, Buffer.from(content, "utf8"))
|
||||
}
|
||||
|
||||
// Check if the document is already open in a tab group that's not in the active editor's column
|
||||
|
||||
@@ -146,7 +146,9 @@ export class TerminalManager {
|
||||
process.run(terminal, command)
|
||||
} else {
|
||||
// docs recommend waiting 3s for shell integration to activate
|
||||
pWaitFor(() => (terminalInfo.terminal as ExtendedTerminal).shellIntegration !== undefined, { timeout: 4000 }).finally(() => {
|
||||
pWaitFor(() => (terminalInfo.terminal as ExtendedTerminal).shellIntegration !== undefined, {
|
||||
timeout: 4000,
|
||||
}).finally(() => {
|
||||
const existingProcess = this.processes.get(terminalInfo.id)
|
||||
if (existingProcess && existingProcess.waitForShellIntegration) {
|
||||
existingProcess.waitForShellIntegration = false
|
||||
|
||||
@@ -19,8 +19,8 @@ export class TerminalRegistry {
|
||||
name: "Roo Cline",
|
||||
iconPath: new vscode.ThemeIcon("rocket"),
|
||||
env: {
|
||||
PAGER: "cat"
|
||||
}
|
||||
PAGER: "cat",
|
||||
},
|
||||
})
|
||||
const newInfo: TerminalInfo = {
|
||||
terminal,
|
||||
|
||||
@@ -7,11 +7,13 @@ jest.mock("vscode")
|
||||
|
||||
describe("TerminalProcess", () => {
|
||||
let terminalProcess: TerminalProcess
|
||||
let mockTerminal: jest.Mocked<vscode.Terminal & {
|
||||
let mockTerminal: jest.Mocked<
|
||||
vscode.Terminal & {
|
||||
shellIntegration: {
|
||||
executeCommand: jest.Mock
|
||||
}
|
||||
}>
|
||||
}
|
||||
>
|
||||
let mockExecution: any
|
||||
let mockStream: AsyncIterableIterator<string>
|
||||
|
||||
@@ -21,7 +23,7 @@ describe("TerminalProcess", () => {
|
||||
// Create properly typed mock terminal
|
||||
mockTerminal = {
|
||||
shellIntegration: {
|
||||
executeCommand: jest.fn()
|
||||
executeCommand: jest.fn(),
|
||||
},
|
||||
name: "Mock Terminal",
|
||||
processId: Promise.resolve(123),
|
||||
@@ -31,12 +33,14 @@ describe("TerminalProcess", () => {
|
||||
dispose: jest.fn(),
|
||||
hide: jest.fn(),
|
||||
show: jest.fn(),
|
||||
sendText: jest.fn()
|
||||
} as unknown as jest.Mocked<vscode.Terminal & {
|
||||
sendText: jest.fn(),
|
||||
} as unknown as jest.Mocked<
|
||||
vscode.Terminal & {
|
||||
shellIntegration: {
|
||||
executeCommand: jest.Mock
|
||||
}
|
||||
}>
|
||||
}
|
||||
>
|
||||
|
||||
// Reset event listeners
|
||||
terminalProcess.removeAllListeners()
|
||||
@@ -62,7 +66,7 @@ describe("TerminalProcess", () => {
|
||||
})()
|
||||
|
||||
mockExecution = {
|
||||
read: jest.fn().mockReturnValue(mockStream)
|
||||
read: jest.fn().mockReturnValue(mockStream),
|
||||
}
|
||||
|
||||
mockTerminal.shellIntegration.executeCommand.mockReturnValue(mockExecution)
|
||||
@@ -81,7 +85,7 @@ describe("TerminalProcess", () => {
|
||||
it("handles terminals without shell integration", async () => {
|
||||
const noShellTerminal = {
|
||||
sendText: jest.fn(),
|
||||
shellIntegration: undefined
|
||||
shellIntegration: undefined,
|
||||
} as unknown as vscode.Terminal
|
||||
|
||||
const noShellPromise = new Promise<void>((resolve) => {
|
||||
@@ -103,20 +107,20 @@ describe("TerminalProcess", () => {
|
||||
})
|
||||
|
||||
// Create a promise that resolves when the first chunk is processed
|
||||
const firstChunkProcessed = new Promise<void>(resolve => {
|
||||
const firstChunkProcessed = new Promise<void>((resolve) => {
|
||||
terminalProcess.on("line", () => resolve())
|
||||
})
|
||||
|
||||
mockStream = (async function* () {
|
||||
yield "compiling...\n"
|
||||
// Wait to ensure hot state check happens after first chunk
|
||||
await new Promise(resolve => setTimeout(resolve, 10))
|
||||
await new Promise((resolve) => setTimeout(resolve, 10))
|
||||
yield "still compiling...\n"
|
||||
yield "done"
|
||||
})()
|
||||
|
||||
mockExecution = {
|
||||
read: jest.fn().mockReturnValue(mockStream)
|
||||
read: jest.fn().mockReturnValue(mockStream),
|
||||
}
|
||||
|
||||
mockTerminal.shellIntegration.executeCommand.mockReturnValue(mockExecution)
|
||||
@@ -178,7 +182,7 @@ describe("TerminalProcess", () => {
|
||||
["output#", "output"],
|
||||
["output> ", "output"],
|
||||
["multi\nline%", "multi\nline"],
|
||||
["no artifacts", "no artifacts"]
|
||||
["no artifacts", "no artifacts"],
|
||||
]
|
||||
|
||||
for (const [input, expected] of cases) {
|
||||
|
||||
@@ -29,8 +29,8 @@ describe("TerminalRegistry", () => {
|
||||
name: "Roo Cline",
|
||||
iconPath: expect.any(Object),
|
||||
env: {
|
||||
PAGER: "cat"
|
||||
}
|
||||
PAGER: "cat",
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user