diff --git a/CHANGELOG.md b/CHANGELOG.md index 9899c18..dff6e00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Roo Cline Changelog +## [2.2.2] + +- Add checkboxes to auto-approve MCP tools + ## [2.2.1] - Fix another diff editing indentation bug diff --git a/README.md b/README.md index 9daf1c6..753b870 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Roo-Cline -A fork of Cline, an autonomous coding agent, with some added experimental configuration and automation features. +A fork of Cline, an autonomous coding agent, optimized for speed and flexibility. - Auto-approval capabilities for commands, write, and browser operations - Support for .clinerules per-project custom instructions - Ability to run side-by-side with Cline @@ -10,6 +10,7 @@ A fork of Cline, an autonomous coding agent, with some added experimental config - Support for copying prompts from the history screen - Support for editing through diffs / handling truncated full-file edits - Support for newer Gemini models (gemini-exp-1206 and gemini-2.0-flash-exp) +- Support for auto-approving MCP tools ## Disclaimer diff --git a/jest.config.js b/jest.config.js index dbca14c..b6012c0 100644 --- a/jest.config.js +++ b/jest.config.js @@ -5,17 +5,35 @@ module.exports = { moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], transform: { '^.+\\.tsx?$': ['ts-jest', { - tsconfig: 'tsconfig.json' + tsconfig: { + "module": "CommonJS", + "moduleResolution": "node", + "esModuleInterop": true, + "allowJs": true + } }] }, testMatch: ['**/__tests__/**/*.test.ts'], moduleNameMapper: { - '^vscode$': '/node_modules/@types/vscode/index.d.ts' + '^vscode$': '/src/__mocks__/vscode.js', + '@modelcontextprotocol/sdk$': '/src/__mocks__/@modelcontextprotocol/sdk/index.js', + '@modelcontextprotocol/sdk/(.*)': '/src/__mocks__/@modelcontextprotocol/sdk/$1', + '^delay$': '/src/__mocks__/delay.js', + '^p-wait-for$': '/src/__mocks__/p-wait-for.js', + '^globby$': '/src/__mocks__/globby.js', + '^serialize-error$': '/src/__mocks__/serialize-error.js', + '^strip-ansi$': '/src/__mocks__/strip-ansi.js', + '^default-shell$': '/src/__mocks__/default-shell.js', + '^os-name$': '/src/__mocks__/os-name.js' }, + transformIgnorePatterns: [ + 'node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)' + ], setupFiles: [], globals: { 'ts-jest': { - diagnostics: false + diagnostics: false, + isolatedModules: true } } }; diff --git a/src/__mocks__/@modelcontextprotocol/sdk/client/index.js b/src/__mocks__/@modelcontextprotocol/sdk/client/index.js new file mode 100644 index 0000000..6ed5825 --- /dev/null +++ b/src/__mocks__/@modelcontextprotocol/sdk/client/index.js @@ -0,0 +1,17 @@ +class Client { + constructor() { + this.request = jest.fn() + } + + connect() { + return Promise.resolve() + } + + close() { + return Promise.resolve() + } +} + +module.exports = { + Client +} \ No newline at end of file diff --git a/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js b/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js new file mode 100644 index 0000000..afa42ad --- /dev/null +++ b/src/__mocks__/@modelcontextprotocol/sdk/client/stdio.js @@ -0,0 +1,22 @@ +class StdioClientTransport { + constructor() { + this.start = jest.fn().mockResolvedValue(undefined) + this.close = jest.fn().mockResolvedValue(undefined) + this.stderr = { + on: jest.fn() + } + } +} + +class StdioServerParameters { + constructor() { + this.command = '' + this.args = [] + this.env = {} + } +} + +module.exports = { + StdioClientTransport, + StdioServerParameters +} \ No newline at end of file diff --git a/src/__mocks__/@modelcontextprotocol/sdk/index.js b/src/__mocks__/@modelcontextprotocol/sdk/index.js new file mode 100644 index 0000000..c6e43e6 --- /dev/null +++ b/src/__mocks__/@modelcontextprotocol/sdk/index.js @@ -0,0 +1,24 @@ +const { Client } = require('./client/index.js') +const { StdioClientTransport, StdioServerParameters } = require('./client/stdio.js') +const { + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ErrorCode, + McpError +} = require('./types.js') + +module.exports = { + Client, + StdioClientTransport, + StdioServerParameters, + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ErrorCode, + McpError +} \ No newline at end of file diff --git a/src/__mocks__/@modelcontextprotocol/sdk/types.js b/src/__mocks__/@modelcontextprotocol/sdk/types.js new file mode 100644 index 0000000..a2b3ea1 --- /dev/null +++ b/src/__mocks__/@modelcontextprotocol/sdk/types.js @@ -0,0 +1,51 @@ +const CallToolResultSchema = { + parse: jest.fn().mockReturnValue({}) +} + +const ListToolsResultSchema = { + parse: jest.fn().mockReturnValue({ + tools: [] + }) +} + +const ListResourcesResultSchema = { + parse: jest.fn().mockReturnValue({ + resources: [] + }) +} + +const ListResourceTemplatesResultSchema = { + parse: jest.fn().mockReturnValue({ + resourceTemplates: [] + }) +} + +const ReadResourceResultSchema = { + parse: jest.fn().mockReturnValue({ + contents: [] + }) +} + +const ErrorCode = { + InvalidRequest: 'InvalidRequest', + MethodNotFound: 'MethodNotFound', + InvalidParams: 'InvalidParams', + InternalError: 'InternalError' +} + +class McpError extends Error { + constructor(code, message) { + super(message) + this.code = code + } +} + +module.exports = { + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ErrorCode, + McpError +} \ No newline at end of file diff --git a/src/__mocks__/McpHub.ts b/src/__mocks__/McpHub.ts new file mode 100644 index 0000000..d39b2d7 --- /dev/null +++ b/src/__mocks__/McpHub.ts @@ -0,0 +1,17 @@ +export class McpHub { + connections = [] + isConnecting = false + + constructor() { + this.toggleToolAlwaysAllow = jest.fn() + this.callTool = jest.fn() + } + + async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise { + return Promise.resolve() + } + + async callTool(serverName: string, toolName: string, toolArguments?: Record): Promise { + return Promise.resolve({ result: 'success' }) + } +} \ No newline at end of file diff --git a/src/__mocks__/default-shell.js b/src/__mocks__/default-shell.js new file mode 100644 index 0000000..f03e4fb --- /dev/null +++ b/src/__mocks__/default-shell.js @@ -0,0 +1,12 @@ +// Mock default shell based on platform +const os = require('os'); + +let defaultShell; +if (os.platform() === 'win32') { + defaultShell = 'cmd.exe'; +} else { + defaultShell = '/bin/bash'; +} + +module.exports = defaultShell; +module.exports.default = defaultShell; \ No newline at end of file diff --git a/src/__mocks__/delay.js b/src/__mocks__/delay.js new file mode 100644 index 0000000..9ecb361 --- /dev/null +++ b/src/__mocks__/delay.js @@ -0,0 +1,6 @@ +function delay(ms) { + return new Promise(resolve => setTimeout(resolve, ms)); +} + +module.exports = delay; +module.exports.default = delay; \ No newline at end of file diff --git a/src/__mocks__/globby.js b/src/__mocks__/globby.js new file mode 100644 index 0000000..2584cd1 --- /dev/null +++ b/src/__mocks__/globby.js @@ -0,0 +1,10 @@ +function globby(patterns, options) { + return Promise.resolve([]); +} + +globby.sync = function(patterns, options) { + return []; +}; + +module.exports = globby; +module.exports.default = globby; \ No newline at end of file diff --git a/src/__mocks__/os-name.js b/src/__mocks__/os-name.js new file mode 100644 index 0000000..e760ff3 --- /dev/null +++ b/src/__mocks__/os-name.js @@ -0,0 +1,6 @@ +function osName() { + return 'macOS'; +} + +module.exports = osName; +module.exports.default = osName; \ No newline at end of file diff --git a/src/__mocks__/p-wait-for.js b/src/__mocks__/p-wait-for.js new file mode 100644 index 0000000..f1e6a68 --- /dev/null +++ b/src/__mocks__/p-wait-for.js @@ -0,0 +1,20 @@ +function pWaitFor(condition, options = {}) { + return new Promise((resolve, reject) => { + const interval = setInterval(() => { + if (condition()) { + clearInterval(interval); + resolve(); + } + }, options.interval || 20); + + if (options.timeout) { + setTimeout(() => { + clearInterval(interval); + reject(new Error('Timed out')); + }, options.timeout); + } + }); +} + +module.exports = pWaitFor; +module.exports.default = pWaitFor; \ No newline at end of file diff --git a/src/__mocks__/serialize-error.js b/src/__mocks__/serialize-error.js new file mode 100644 index 0000000..bf01dc1 --- /dev/null +++ b/src/__mocks__/serialize-error.js @@ -0,0 +1,25 @@ +function serializeError(error) { + if (error instanceof Error) { + return { + name: error.name, + message: error.message, + stack: error.stack + }; + } + return error; +} + +function deserializeError(errorData) { + if (errorData && typeof errorData === 'object') { + const error = new Error(errorData.message); + error.name = errorData.name; + error.stack = errorData.stack; + return error; + } + return errorData; +} + +module.exports = { + serializeError, + deserializeError +}; \ No newline at end of file diff --git a/src/__mocks__/strip-ansi.js b/src/__mocks__/strip-ansi.js new file mode 100644 index 0000000..bf7aff9 --- /dev/null +++ b/src/__mocks__/strip-ansi.js @@ -0,0 +1,7 @@ +function stripAnsi(string) { + // Simple mock that just returns the input string + return string; +} + +module.exports = stripAnsi; +module.exports.default = stripAnsi; \ No newline at end of file diff --git a/src/__mocks__/vscode.js b/src/__mocks__/vscode.js new file mode 100644 index 0000000..23f3ae5 --- /dev/null +++ b/src/__mocks__/vscode.js @@ -0,0 +1,57 @@ +const vscode = { + window: { + showInformationMessage: jest.fn(), + showErrorMessage: jest.fn(), + createTextEditorDecorationType: jest.fn().mockReturnValue({ + dispose: jest.fn() + }) + }, + workspace: { + onDidSaveTextDocument: jest.fn() + }, + Disposable: class { + dispose() {} + }, + Uri: { + file: (path) => ({ + fsPath: path, + scheme: 'file', + authority: '', + path: path, + query: '', + fragment: '', + with: jest.fn(), + toJSON: jest.fn() + }) + }, + EventEmitter: class { + constructor() { + this.event = jest.fn(); + this.fire = jest.fn(); + } + }, + ConfigurationTarget: { + Global: 1, + Workspace: 2, + WorkspaceFolder: 3 + }, + Position: class { + constructor(line, character) { + this.line = line; + this.character = character; + } + }, + Range: class { + constructor(startLine, startCharacter, endLine, endCharacter) { + this.start = new vscode.Position(startLine, startCharacter); + this.end = new vscode.Position(endLine, endCharacter); + } + }, + ThemeColor: class { + constructor(id) { + this.id = id; + } + } +}; + +module.exports = vscode; \ No newline at end of file diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 48a12c0..4218ab6 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -550,6 +550,18 @@ export class ClineProvider implements vscode.WebviewViewProvider { } break } + case "toggleToolAlwaysAllow": { + try { + await this.mcpHub?.toggleToolAlwaysAllow( + message.serverName!, + message.toolName!, + message.alwaysAllow! + ) + } catch (error) { + console.error(`Failed to toggle auto-approve for tool ${message.toolName}:`, error) + } + break + } // Add more switch case statements here as more webview message commands // are created within the webview context (i.e. inside media/main.js) case "playSound": diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 18a4685..715410e 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -33,14 +33,17 @@ export type McpConnection = { } // StdioServerParameters +const AlwaysAllowSchema = z.array(z.string()).default([]) + const StdioConfigSchema = z.object({ command: z.string(), args: z.array(z.string()).optional(), env: z.record(z.string()).optional(), + alwaysAllow: AlwaysAllowSchema.optional() }) const McpSettingsSchema = z.object({ - mcpServers: z.record(StdioConfigSchema), + mcpServers: z.record(StdioConfigSchema) }) export class McpHub { @@ -285,7 +288,21 @@ export class McpHub { const response = await this.connections .find((conn) => conn.server.name === serverName) ?.client.request({ method: "tools/list" }, ListToolsResultSchema) - return response?.tools || [] + + // Get always allow settings + const settingsPath = await this.getMcpSettingsFilePath() + const content = await fs.readFile(settingsPath, "utf-8") + const config = JSON.parse(content) + const alwaysAllowConfig = config.mcpServers[serverName]?.alwaysAllow || [] + + // Mark tools as always allowed based on settings + const tools = (response?.tools || []).map(tool => ({ + ...tool, + alwaysAllow: alwaysAllowConfig.includes(tool.name) + })) + + console.log(`[MCP] Fetched tools for ${serverName}:`, tools) + return tools } catch (error) { // console.error(`Failed to fetch tools for ${serverName}:`, error) return [] @@ -478,6 +495,7 @@ export class McpHub { `No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`, ) } + return await connection.client.request( { method: "tools/call", @@ -490,6 +508,45 @@ export class McpHub { ) } + async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise { + try { + const settingsPath = await this.getMcpSettingsFilePath() + const content = await fs.readFile(settingsPath, "utf-8") + const config = JSON.parse(content) + + // Initialize alwaysAllow if it doesn't exist + if (!config.mcpServers[serverName].alwaysAllow) { + config.mcpServers[serverName].alwaysAllow = [] + } + + const alwaysAllow = config.mcpServers[serverName].alwaysAllow + const toolIndex = alwaysAllow.indexOf(toolName) + + if (shouldAllow && toolIndex === -1) { + // Add tool to always allow list + alwaysAllow.push(toolName) + } else if (!shouldAllow && toolIndex !== -1) { + // Remove tool from always allow list + alwaysAllow.splice(toolIndex, 1) + } + + // Write updated config back to file + await fs.writeFile(settingsPath, JSON.stringify(config, null, 2)) + + // Update the tools list to reflect the change + const connection = this.connections.find(conn => conn.server.name === serverName) + if (connection) { + connection.server.tools = await this.fetchToolsList(serverName) + await this.notifyWebviewOfServerChanges() + } + + } catch (error) { + console.error("Failed to update always allow settings:", error) + vscode.window.showErrorMessage("Failed to update always allow settings") + throw error // Re-throw to ensure the error is properly handled + } + } + async dispose(): Promise { this.removeAllFileWatchers() for (const connection of this.connections) { diff --git a/src/services/mcp/__tests__/McpHub.test.ts b/src/services/mcp/__tests__/McpHub.test.ts new file mode 100644 index 0000000..cf4899b --- /dev/null +++ b/src/services/mcp/__tests__/McpHub.test.ts @@ -0,0 +1,193 @@ +import type { McpHub as McpHubType } from '../McpHub' +import type { ClineProvider } from '../../../core/webview/ClineProvider' +import type { ExtensionContext, Uri } from 'vscode' +import type { McpConnection } from '../McpHub' + +const vscode = require('vscode') +const fs = require('fs/promises') +const { McpHub } = require('../McpHub') + +jest.mock('vscode') +jest.mock('fs/promises') +jest.mock('../../../core/webview/ClineProvider') + +describe('McpHub', () => { + let mcpHub: McpHubType + let mockProvider: Partial + const mockSettingsPath = '/mock/settings/path/cline_mcp_settings.json' + + beforeEach(() => { + jest.clearAllMocks() + + const mockUri: Uri = { + scheme: 'file', + authority: '', + path: '/test/path', + query: '', + fragment: '', + fsPath: '/test/path', + with: jest.fn(), + toJSON: jest.fn() + } + + mockProvider = { + ensureSettingsDirectoryExists: jest.fn().mockResolvedValue('/mock/settings/path'), + ensureMcpServersDirectoryExists: jest.fn().mockResolvedValue('/mock/settings/path'), + postMessageToWebview: jest.fn(), + context: { + subscriptions: [], + workspaceState: {} as any, + globalState: {} as any, + secrets: {} as any, + extensionUri: mockUri, + extensionPath: '/test/path', + storagePath: '/test/storage', + globalStoragePath: '/test/global-storage', + environmentVariableCollection: {} as any, + extension: { + id: 'test-extension', + extensionUri: mockUri, + extensionPath: '/test/path', + extensionKind: 1, + isActive: true, + packageJSON: { + version: '1.0.0' + }, + activate: jest.fn(), + exports: undefined + } as any, + asAbsolutePath: (path: string) => path, + storageUri: mockUri, + globalStorageUri: mockUri, + logUri: mockUri, + extensionMode: 1, + logPath: '/test/path', + languageModelAccessInformation: {} as any + } as ExtensionContext + } + + // Mock fs.readFile for initial settings + ;(fs.readFile as jest.Mock).mockResolvedValue(JSON.stringify({ + mcpServers: { + 'test-server': { + command: 'node', + args: ['test.js'], + alwaysAllow: ['allowed-tool'] + } + } + })) + + mcpHub = new McpHub(mockProvider as ClineProvider) + }) + + describe('toggleToolAlwaysAllow', () => { + it('should add tool to always allow list when enabling', async () => { + const mockConfig = { + mcpServers: { + 'test-server': { + command: 'node', + args: ['test.js'], + alwaysAllow: [] + } + } + } + + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + + await mcpHub.toggleToolAlwaysAllow('test-server', 'new-tool', true) + + // Verify the config was updated correctly + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toContain('new-tool') + }) + + it('should remove tool from always allow list when disabling', async () => { + const mockConfig = { + mcpServers: { + 'test-server': { + command: 'node', + args: ['test.js'], + alwaysAllow: ['existing-tool'] + } + } + } + + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + + await mcpHub.toggleToolAlwaysAllow('test-server', 'existing-tool', false) + + // Verify the config was updated correctly + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers['test-server'].alwaysAllow).not.toContain('existing-tool') + }) + + it('should initialize alwaysAllow if it does not exist', async () => { + const mockConfig = { + mcpServers: { + 'test-server': { + command: 'node', + args: ['test.js'] + } + } + } + + // Mock reading initial config + ;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig)) + + await mcpHub.toggleToolAlwaysAllow('test-server', 'new-tool', true) + + // Verify the config was updated with initialized alwaysAllow + const writeCall = (fs.writeFile as jest.Mock).mock.calls[0] + const writtenConfig = JSON.parse(writeCall[1]) + expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toBeDefined() + expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toContain('new-tool') + }) + }) + + describe('callTool', () => { + it('should execute tool successfully', async () => { + // Mock the connection with a minimal client implementation + const mockConnection: McpConnection = { + server: { + name: 'test-server', + config: JSON.stringify({}), + status: 'connected' as const + }, + client: { + request: jest.fn().mockResolvedValue({ result: 'success' }) + } as any, + transport: { + start: jest.fn(), + close: jest.fn(), + stderr: { on: jest.fn() } + } as any + } + + mcpHub.connections = [mockConnection] + + await mcpHub.callTool('test-server', 'some-tool', {}) + + // Verify the request was made with correct parameters + expect(mockConnection.client.request).toHaveBeenCalledWith( + { + method: 'tools/call', + params: { + name: 'some-tool', + arguments: {} + } + }, + expect.any(Object) + ) + }) + + it('should throw error if server not found', async () => { + await expect(mcpHub.callTool('non-existent-server', 'some-tool', {})) + .rejects + .toThrow('No connection found for server: non-existent-server') + }) + }) +}) \ No newline at end of file diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 519756d..fd5b63e 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -34,6 +34,7 @@ export interface WebviewMessage { | "diffEnabled" | "openMcpSettings" | "restartMcpServer" + | "toggleToolAlwaysAllow" text?: string askResponse?: ClineAskResponse apiConfiguration?: ApiConfiguration @@ -41,6 +42,10 @@ export interface WebviewMessage { bool?: boolean commands?: string[] audioType?: AudioType + // For toggleToolAutoApprove + serverName?: string + toolName?: string + alwaysAllow?: boolean } export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse" diff --git a/src/shared/mcp.ts b/src/shared/mcp.ts index 82efae2..a00b343 100644 --- a/src/shared/mcp.ts +++ b/src/shared/mcp.ts @@ -12,6 +12,7 @@ export type McpTool = { name: string description?: string inputSchema?: object + alwaysAllow?: boolean } export type McpResource = { diff --git a/webview-ui/src/components/chat/ChatRow.tsx b/webview-ui/src/components/chat/ChatRow.tsx index d068220..6d9042f 100644 --- a/webview-ui/src/components/chat/ChatRow.tsx +++ b/webview-ui/src/components/chat/ChatRow.tsx @@ -813,14 +813,19 @@ export const ChatRowContent = ({ {useMcpServer.type === "use_mcp_tool" && ( <> - tool.name === useMcpServer.toolName) - ?.description || "", - }} - /> +
e.stopPropagation()}> + tool.name === useMcpServer.toolName) + ?.description || "", + alwaysAllow: server?.tools?.find((tool) => tool.name === useMcpServer.toolName) + ?.alwaysAllow || false, + }} + serverName={useMcpServer.serverName} + /> +
{useMcpServer.arguments && useMcpServer.arguments !== "{}" && (
{ - const { version, clineMessages: messages, taskHistory, apiConfiguration, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, allowedCommands } = useExtensionState() + const { version, clineMessages: messages, taskHistory, apiConfiguration, mcpServers, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, allowedCommands } = useExtensionState() //const task = messages.length > 0 ? (messages[0].say === "task" ? messages[0] : undefined) : undefined) : undefined const task = useMemo(() => messages.at(0), [messages]) // leaving this less safe version here since if the first message is not a task, then the extension is in a bad state and needs to be debugged (see Cline.abort) @@ -767,6 +768,19 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie return false } + const isMcpToolAlwaysAllowed = () => { + const lastMessage = messages.at(-1) + if (lastMessage?.type === "ask" && lastMessage.ask === "use_mcp_server" && lastMessage.text) { + const mcpServerUse = JSON.parse(lastMessage.text) as { type: string; serverName: string; toolName: string } + if (mcpServerUse.type === "use_mcp_tool") { + const server = mcpServers?.find((s: McpServer) => s.name === mcpServerUse.serverName) + const tool = server?.tools?.find((t: McpTool) => t.name === mcpServerUse.toolName) + return tool?.alwaysAllow || false + } + } + return false + } + const isAllowedCommand = () => { const lastMessage = messages.at(-1) if (lastMessage?.type === "ask" && lastMessage.text) { @@ -788,11 +802,12 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie (alwaysAllowBrowser && clineAsk === "browser_action_launch") || (alwaysAllowReadOnly && clineAsk === "tool" && isReadOnlyToolAction()) || (alwaysAllowWrite && clineAsk === "tool" && isWriteToolAction()) || - (alwaysAllowExecute && clineAsk === "command" && isAllowedCommand()) + (alwaysAllowExecute && clineAsk === "command" && isAllowedCommand()) || + (clineAsk === "use_mcp_server" && isMcpToolAlwaysAllowed()) ) { handlePrimaryButtonClick() } - }, [clineAsk, enableButtons, handlePrimaryButtonClick, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, messages, allowedCommands]) + }, [clineAsk, enableButtons, handlePrimaryButtonClick, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, messages, allowedCommands, mcpServers]) return (
{ +const McpToolRow = ({ tool, serverName }: McpToolRowProps) => { + const handleAlwaysAllowChange = () => { + if (!serverName) return; + + vscode.postMessage({ + type: "toggleToolAlwaysAllow", + serverName, + toolName: tool.name, + alwaysAllow: !tool.alwaysAllow + }); + } + return (
-
- - {tool.name} +
e.stopPropagation()}> +
+ + {tool.name} +
+ {serverName && ( + + Always allow + + )}
{tool.description && (
{
{server.tools.map((tool) => ( - + ))}
) : ( diff --git a/webview-ui/src/components/mcp/__tests__/McpToolRow.test.tsx b/webview-ui/src/components/mcp/__tests__/McpToolRow.test.tsx new file mode 100644 index 0000000..ff708db --- /dev/null +++ b/webview-ui/src/components/mcp/__tests__/McpToolRow.test.tsx @@ -0,0 +1,107 @@ +import React from 'react' +import { render, fireEvent, screen } from '@testing-library/react' +import McpToolRow from '../McpToolRow' +import { vscode } from '../../../utils/vscode' + +jest.mock('../../../utils/vscode', () => ({ + vscode: { + postMessage: jest.fn() + } +})) + +describe('McpToolRow', () => { + const mockTool = { + name: 'test-tool', + description: 'A test tool', + alwaysAllow: false + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('renders tool name and description', () => { + render() + + expect(screen.getByText('test-tool')).toBeInTheDocument() + expect(screen.getByText('A test tool')).toBeInTheDocument() + }) + + it('does not show always allow checkbox when serverName is not provided', () => { + render() + + expect(screen.queryByText('Always allow')).not.toBeInTheDocument() + }) + + it('shows always allow checkbox when serverName is provided', () => { + render() + + expect(screen.getByText('Always allow')).toBeInTheDocument() + }) + + it('sends message to toggle always allow when checkbox is clicked', () => { + render() + + const checkbox = screen.getByRole('checkbox') + fireEvent.click(checkbox) + + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: 'toggleToolAlwaysAllow', + serverName: 'test-server', + toolName: 'test-tool', + alwaysAllow: true + }) + }) + + it('reflects always allow state in checkbox', () => { + const alwaysAllowedTool = { + ...mockTool, + alwaysAllow: true + } + + render() + + const checkbox = screen.getByRole('checkbox') + expect(checkbox).toBeChecked() + }) + + it('prevents event propagation when clicking the checkbox', () => { + const mockStopPropagation = jest.fn() + render() + + const container = screen.getByTestId('tool-row-container') + fireEvent.click(container, { + stopPropagation: mockStopPropagation + }) + + expect(mockStopPropagation).toHaveBeenCalled() + }) + + it('displays input schema parameters when provided', () => { + const toolWithSchema = { + ...mockTool, + inputSchema: { + type: 'object', + properties: { + param1: { + type: 'string', + description: 'First parameter' + }, + param2: { + type: 'number', + description: 'Second parameter' + } + }, + required: ['param1'] + } + } + + render() + + expect(screen.getByText('Parameters')).toBeInTheDocument() + expect(screen.getByText('param1')).toBeInTheDocument() + expect(screen.getByText('param2')).toBeInTheDocument() + expect(screen.getByText('First parameter')).toBeInTheDocument() + expect(screen.getByText('Second parameter')).toBeInTheDocument() + }) +}) \ No newline at end of file