MCP checkbox for always allow

This commit is contained in:
Matt Rubens
2024-12-13 14:23:31 -05:00
parent 6ee118e0a2
commit 1346f1280c
26 changed files with 744 additions and 22 deletions

View File

@@ -1,5 +1,9 @@
# Roo Cline Changelog
## [2.2.2]
- Add checkboxes to auto-approve MCP tools
## [2.2.1]
- Fix another diff editing indentation bug

View File

@@ -1,6 +1,6 @@
# Roo-Cline
A fork of Cline, an autonomous coding agent, with some added experimental configuration and automation features.
A fork of Cline, an autonomous coding agent, optimized for speed and flexibility.
- Auto-approval capabilities for commands, write, and browser operations
- Support for .clinerules per-project custom instructions
- Ability to run side-by-side with Cline
@@ -10,6 +10,7 @@ A fork of Cline, an autonomous coding agent, with some added experimental config
- Support for copying prompts from the history screen
- Support for editing through diffs / handling truncated full-file edits
- Support for newer Gemini models (gemini-exp-1206 and gemini-2.0-flash-exp)
- Support for auto-approving MCP tools
## Disclaimer

View File

@@ -5,17 +5,35 @@ module.exports = {
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
transform: {
'^.+\\.tsx?$': ['ts-jest', {
tsconfig: 'tsconfig.json'
tsconfig: {
"module": "CommonJS",
"moduleResolution": "node",
"esModuleInterop": true,
"allowJs": true
}
}]
},
testMatch: ['**/__tests__/**/*.test.ts'],
moduleNameMapper: {
'^vscode$': '<rootDir>/node_modules/@types/vscode/index.d.ts'
'^vscode$': '<rootDir>/src/__mocks__/vscode.js',
'@modelcontextprotocol/sdk$': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/index.js',
'@modelcontextprotocol/sdk/(.*)': '<rootDir>/src/__mocks__/@modelcontextprotocol/sdk/$1',
'^delay$': '<rootDir>/src/__mocks__/delay.js',
'^p-wait-for$': '<rootDir>/src/__mocks__/p-wait-for.js',
'^globby$': '<rootDir>/src/__mocks__/globby.js',
'^serialize-error$': '<rootDir>/src/__mocks__/serialize-error.js',
'^strip-ansi$': '<rootDir>/src/__mocks__/strip-ansi.js',
'^default-shell$': '<rootDir>/src/__mocks__/default-shell.js',
'^os-name$': '<rootDir>/src/__mocks__/os-name.js'
},
transformIgnorePatterns: [
'node_modules/(?!(@modelcontextprotocol|delay|p-wait-for|globby|serialize-error|strip-ansi|default-shell|os-name)/)'
],
setupFiles: [],
globals: {
'ts-jest': {
diagnostics: false
diagnostics: false,
isolatedModules: true
}
}
};

View File

@@ -0,0 +1,17 @@
class Client {
constructor() {
this.request = jest.fn()
}
connect() {
return Promise.resolve()
}
close() {
return Promise.resolve()
}
}
module.exports = {
Client
}

View File

@@ -0,0 +1,22 @@
class StdioClientTransport {
constructor() {
this.start = jest.fn().mockResolvedValue(undefined)
this.close = jest.fn().mockResolvedValue(undefined)
this.stderr = {
on: jest.fn()
}
}
}
class StdioServerParameters {
constructor() {
this.command = ''
this.args = []
this.env = {}
}
}
module.exports = {
StdioClientTransport,
StdioServerParameters
}

View File

@@ -0,0 +1,24 @@
const { Client } = require('./client/index.js')
const { StdioClientTransport, StdioServerParameters } = require('./client/stdio.js')
const {
CallToolResultSchema,
ListToolsResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ReadResourceResultSchema,
ErrorCode,
McpError
} = require('./types.js')
module.exports = {
Client,
StdioClientTransport,
StdioServerParameters,
CallToolResultSchema,
ListToolsResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ReadResourceResultSchema,
ErrorCode,
McpError
}

View File

@@ -0,0 +1,51 @@
const CallToolResultSchema = {
parse: jest.fn().mockReturnValue({})
}
const ListToolsResultSchema = {
parse: jest.fn().mockReturnValue({
tools: []
})
}
const ListResourcesResultSchema = {
parse: jest.fn().mockReturnValue({
resources: []
})
}
const ListResourceTemplatesResultSchema = {
parse: jest.fn().mockReturnValue({
resourceTemplates: []
})
}
const ReadResourceResultSchema = {
parse: jest.fn().mockReturnValue({
contents: []
})
}
const ErrorCode = {
InvalidRequest: 'InvalidRequest',
MethodNotFound: 'MethodNotFound',
InvalidParams: 'InvalidParams',
InternalError: 'InternalError'
}
class McpError extends Error {
constructor(code, message) {
super(message)
this.code = code
}
}
module.exports = {
CallToolResultSchema,
ListToolsResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ReadResourceResultSchema,
ErrorCode,
McpError
}

17
src/__mocks__/McpHub.ts Normal file
View File

@@ -0,0 +1,17 @@
export class McpHub {
connections = []
isConnecting = false
constructor() {
this.toggleToolAlwaysAllow = jest.fn()
this.callTool = jest.fn()
}
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
return Promise.resolve()
}
async callTool(serverName: string, toolName: string, toolArguments?: Record<string, unknown>): Promise<any> {
return Promise.resolve({ result: 'success' })
}
}

View File

@@ -0,0 +1,12 @@
// Mock default shell based on platform
const os = require('os');
let defaultShell;
if (os.platform() === 'win32') {
defaultShell = 'cmd.exe';
} else {
defaultShell = '/bin/bash';
}
module.exports = defaultShell;
module.exports.default = defaultShell;

6
src/__mocks__/delay.js Normal file
View File

@@ -0,0 +1,6 @@
function delay(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
module.exports = delay;
module.exports.default = delay;

10
src/__mocks__/globby.js Normal file
View File

@@ -0,0 +1,10 @@
function globby(patterns, options) {
return Promise.resolve([]);
}
globby.sync = function(patterns, options) {
return [];
};
module.exports = globby;
module.exports.default = globby;

6
src/__mocks__/os-name.js Normal file
View File

@@ -0,0 +1,6 @@
function osName() {
return 'macOS';
}
module.exports = osName;
module.exports.default = osName;

View File

@@ -0,0 +1,20 @@
function pWaitFor(condition, options = {}) {
return new Promise((resolve, reject) => {
const interval = setInterval(() => {
if (condition()) {
clearInterval(interval);
resolve();
}
}, options.interval || 20);
if (options.timeout) {
setTimeout(() => {
clearInterval(interval);
reject(new Error('Timed out'));
}, options.timeout);
}
});
}
module.exports = pWaitFor;
module.exports.default = pWaitFor;

View File

@@ -0,0 +1,25 @@
function serializeError(error) {
if (error instanceof Error) {
return {
name: error.name,
message: error.message,
stack: error.stack
};
}
return error;
}
function deserializeError(errorData) {
if (errorData && typeof errorData === 'object') {
const error = new Error(errorData.message);
error.name = errorData.name;
error.stack = errorData.stack;
return error;
}
return errorData;
}
module.exports = {
serializeError,
deserializeError
};

View File

@@ -0,0 +1,7 @@
function stripAnsi(string) {
// Simple mock that just returns the input string
return string;
}
module.exports = stripAnsi;
module.exports.default = stripAnsi;

57
src/__mocks__/vscode.js Normal file
View File

@@ -0,0 +1,57 @@
const vscode = {
window: {
showInformationMessage: jest.fn(),
showErrorMessage: jest.fn(),
createTextEditorDecorationType: jest.fn().mockReturnValue({
dispose: jest.fn()
})
},
workspace: {
onDidSaveTextDocument: jest.fn()
},
Disposable: class {
dispose() {}
},
Uri: {
file: (path) => ({
fsPath: path,
scheme: 'file',
authority: '',
path: path,
query: '',
fragment: '',
with: jest.fn(),
toJSON: jest.fn()
})
},
EventEmitter: class {
constructor() {
this.event = jest.fn();
this.fire = jest.fn();
}
},
ConfigurationTarget: {
Global: 1,
Workspace: 2,
WorkspaceFolder: 3
},
Position: class {
constructor(line, character) {
this.line = line;
this.character = character;
}
},
Range: class {
constructor(startLine, startCharacter, endLine, endCharacter) {
this.start = new vscode.Position(startLine, startCharacter);
this.end = new vscode.Position(endLine, endCharacter);
}
},
ThemeColor: class {
constructor(id) {
this.id = id;
}
}
};
module.exports = vscode;

View File

@@ -550,6 +550,18 @@ export class ClineProvider implements vscode.WebviewViewProvider {
}
break
}
case "toggleToolAlwaysAllow": {
try {
await this.mcpHub?.toggleToolAlwaysAllow(
message.serverName!,
message.toolName!,
message.alwaysAllow!
)
} catch (error) {
console.error(`Failed to toggle auto-approve for tool ${message.toolName}:`, error)
}
break
}
// Add more switch case statements here as more webview message commands
// are created within the webview context (i.e. inside media/main.js)
case "playSound":

View File

@@ -33,14 +33,17 @@ export type McpConnection = {
}
// StdioServerParameters
const AlwaysAllowSchema = z.array(z.string()).default([])
const StdioConfigSchema = z.object({
command: z.string(),
args: z.array(z.string()).optional(),
env: z.record(z.string()).optional(),
alwaysAllow: AlwaysAllowSchema.optional()
})
const McpSettingsSchema = z.object({
mcpServers: z.record(StdioConfigSchema),
mcpServers: z.record(StdioConfigSchema)
})
export class McpHub {
@@ -285,7 +288,21 @@ export class McpHub {
const response = await this.connections
.find((conn) => conn.server.name === serverName)
?.client.request({ method: "tools/list" }, ListToolsResultSchema)
return response?.tools || []
// Get always allow settings
const settingsPath = await this.getMcpSettingsFilePath()
const content = await fs.readFile(settingsPath, "utf-8")
const config = JSON.parse(content)
const alwaysAllowConfig = config.mcpServers[serverName]?.alwaysAllow || []
// Mark tools as always allowed based on settings
const tools = (response?.tools || []).map(tool => ({
...tool,
alwaysAllow: alwaysAllowConfig.includes(tool.name)
}))
console.log(`[MCP] Fetched tools for ${serverName}:`, tools)
return tools
} catch (error) {
// console.error(`Failed to fetch tools for ${serverName}:`, error)
return []
@@ -478,6 +495,7 @@ export class McpHub {
`No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
)
}
return await connection.client.request(
{
method: "tools/call",
@@ -490,6 +508,45 @@ export class McpHub {
)
}
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
try {
const settingsPath = await this.getMcpSettingsFilePath()
const content = await fs.readFile(settingsPath, "utf-8")
const config = JSON.parse(content)
// Initialize alwaysAllow if it doesn't exist
if (!config.mcpServers[serverName].alwaysAllow) {
config.mcpServers[serverName].alwaysAllow = []
}
const alwaysAllow = config.mcpServers[serverName].alwaysAllow
const toolIndex = alwaysAllow.indexOf(toolName)
if (shouldAllow && toolIndex === -1) {
// Add tool to always allow list
alwaysAllow.push(toolName)
} else if (!shouldAllow && toolIndex !== -1) {
// Remove tool from always allow list
alwaysAllow.splice(toolIndex, 1)
}
// Write updated config back to file
await fs.writeFile(settingsPath, JSON.stringify(config, null, 2))
// Update the tools list to reflect the change
const connection = this.connections.find(conn => conn.server.name === serverName)
if (connection) {
connection.server.tools = await this.fetchToolsList(serverName)
await this.notifyWebviewOfServerChanges()
}
} catch (error) {
console.error("Failed to update always allow settings:", error)
vscode.window.showErrorMessage("Failed to update always allow settings")
throw error // Re-throw to ensure the error is properly handled
}
}
async dispose(): Promise<void> {
this.removeAllFileWatchers()
for (const connection of this.connections) {

View File

@@ -0,0 +1,193 @@
import type { McpHub as McpHubType } from '../McpHub'
import type { ClineProvider } from '../../../core/webview/ClineProvider'
import type { ExtensionContext, Uri } from 'vscode'
import type { McpConnection } from '../McpHub'
const vscode = require('vscode')
const fs = require('fs/promises')
const { McpHub } = require('../McpHub')
jest.mock('vscode')
jest.mock('fs/promises')
jest.mock('../../../core/webview/ClineProvider')
describe('McpHub', () => {
let mcpHub: McpHubType
let mockProvider: Partial<ClineProvider>
const mockSettingsPath = '/mock/settings/path/cline_mcp_settings.json'
beforeEach(() => {
jest.clearAllMocks()
const mockUri: Uri = {
scheme: 'file',
authority: '',
path: '/test/path',
query: '',
fragment: '',
fsPath: '/test/path',
with: jest.fn(),
toJSON: jest.fn()
}
mockProvider = {
ensureSettingsDirectoryExists: jest.fn().mockResolvedValue('/mock/settings/path'),
ensureMcpServersDirectoryExists: jest.fn().mockResolvedValue('/mock/settings/path'),
postMessageToWebview: jest.fn(),
context: {
subscriptions: [],
workspaceState: {} as any,
globalState: {} as any,
secrets: {} as any,
extensionUri: mockUri,
extensionPath: '/test/path',
storagePath: '/test/storage',
globalStoragePath: '/test/global-storage',
environmentVariableCollection: {} as any,
extension: {
id: 'test-extension',
extensionUri: mockUri,
extensionPath: '/test/path',
extensionKind: 1,
isActive: true,
packageJSON: {
version: '1.0.0'
},
activate: jest.fn(),
exports: undefined
} as any,
asAbsolutePath: (path: string) => path,
storageUri: mockUri,
globalStorageUri: mockUri,
logUri: mockUri,
extensionMode: 1,
logPath: '/test/path',
languageModelAccessInformation: {} as any
} as ExtensionContext
}
// Mock fs.readFile for initial settings
;(fs.readFile as jest.Mock).mockResolvedValue(JSON.stringify({
mcpServers: {
'test-server': {
command: 'node',
args: ['test.js'],
alwaysAllow: ['allowed-tool']
}
}
}))
mcpHub = new McpHub(mockProvider as ClineProvider)
})
describe('toggleToolAlwaysAllow', () => {
it('should add tool to always allow list when enabling', async () => {
const mockConfig = {
mcpServers: {
'test-server': {
command: 'node',
args: ['test.js'],
alwaysAllow: []
}
}
}
// Mock reading initial config
;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig))
await mcpHub.toggleToolAlwaysAllow('test-server', 'new-tool', true)
// Verify the config was updated correctly
const writeCall = (fs.writeFile as jest.Mock).mock.calls[0]
const writtenConfig = JSON.parse(writeCall[1])
expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toContain('new-tool')
})
it('should remove tool from always allow list when disabling', async () => {
const mockConfig = {
mcpServers: {
'test-server': {
command: 'node',
args: ['test.js'],
alwaysAllow: ['existing-tool']
}
}
}
// Mock reading initial config
;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig))
await mcpHub.toggleToolAlwaysAllow('test-server', 'existing-tool', false)
// Verify the config was updated correctly
const writeCall = (fs.writeFile as jest.Mock).mock.calls[0]
const writtenConfig = JSON.parse(writeCall[1])
expect(writtenConfig.mcpServers['test-server'].alwaysAllow).not.toContain('existing-tool')
})
it('should initialize alwaysAllow if it does not exist', async () => {
const mockConfig = {
mcpServers: {
'test-server': {
command: 'node',
args: ['test.js']
}
}
}
// Mock reading initial config
;(fs.readFile as jest.Mock).mockResolvedValueOnce(JSON.stringify(mockConfig))
await mcpHub.toggleToolAlwaysAllow('test-server', 'new-tool', true)
// Verify the config was updated with initialized alwaysAllow
const writeCall = (fs.writeFile as jest.Mock).mock.calls[0]
const writtenConfig = JSON.parse(writeCall[1])
expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toBeDefined()
expect(writtenConfig.mcpServers['test-server'].alwaysAllow).toContain('new-tool')
})
})
describe('callTool', () => {
it('should execute tool successfully', async () => {
// Mock the connection with a minimal client implementation
const mockConnection: McpConnection = {
server: {
name: 'test-server',
config: JSON.stringify({}),
status: 'connected' as const
},
client: {
request: jest.fn().mockResolvedValue({ result: 'success' })
} as any,
transport: {
start: jest.fn(),
close: jest.fn(),
stderr: { on: jest.fn() }
} as any
}
mcpHub.connections = [mockConnection]
await mcpHub.callTool('test-server', 'some-tool', {})
// Verify the request was made with correct parameters
expect(mockConnection.client.request).toHaveBeenCalledWith(
{
method: 'tools/call',
params: {
name: 'some-tool',
arguments: {}
}
},
expect.any(Object)
)
})
it('should throw error if server not found', async () => {
await expect(mcpHub.callTool('non-existent-server', 'some-tool', {}))
.rejects
.toThrow('No connection found for server: non-existent-server')
})
})
})

View File

@@ -34,6 +34,7 @@ export interface WebviewMessage {
| "diffEnabled"
| "openMcpSettings"
| "restartMcpServer"
| "toggleToolAlwaysAllow"
text?: string
askResponse?: ClineAskResponse
apiConfiguration?: ApiConfiguration
@@ -41,6 +42,10 @@ export interface WebviewMessage {
bool?: boolean
commands?: string[]
audioType?: AudioType
// For toggleToolAutoApprove
serverName?: string
toolName?: string
alwaysAllow?: boolean
}
export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"

View File

@@ -12,6 +12,7 @@ export type McpTool = {
name: string
description?: string
inputSchema?: object
alwaysAllow?: boolean
}
export type McpResource = {

View File

@@ -813,14 +813,19 @@ export const ChatRowContent = ({
{useMcpServer.type === "use_mcp_tool" && (
<>
<div onClick={(e) => e.stopPropagation()}>
<McpToolRow
tool={{
name: useMcpServer.toolName || "",
description:
server?.tools?.find((tool) => tool.name === useMcpServer.toolName)
?.description || "",
alwaysAllow: server?.tools?.find((tool) => tool.name === useMcpServer.toolName)
?.alwaysAllow || false,
}}
serverName={useMcpServer.serverName}
/>
</div>
{useMcpServer.arguments && useMcpServer.arguments !== "{}" && (
<div style={{ marginTop: "8px" }}>
<div

View File

@@ -11,6 +11,7 @@ import {
ClineSayTool,
ExtensionMessage,
} from "../../../../src/shared/ExtensionMessage"
import { McpServer, McpTool } from "../../../../src/shared/mcp"
import { findLast } from "../../../../src/shared/array"
import { combineApiRequests } from "../../../../src/shared/combineApiRequests"
import { combineCommandSequences } from "../../../../src/shared/combineCommandSequences"
@@ -36,7 +37,7 @@ interface ChatViewProps {
export const MAX_IMAGES_PER_MESSAGE = 20 // Anthropic limits to 20 images
const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryView }: ChatViewProps) => {
const { version, clineMessages: messages, taskHistory, apiConfiguration, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, allowedCommands } = useExtensionState()
const { version, clineMessages: messages, taskHistory, apiConfiguration, mcpServers, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, allowedCommands } = useExtensionState()
//const task = messages.length > 0 ? (messages[0].say === "task" ? messages[0] : undefined) : undefined) : undefined
const task = useMemo(() => messages.at(0), [messages]) // leaving this less safe version here since if the first message is not a task, then the extension is in a bad state and needs to be debugged (see Cline.abort)
@@ -767,6 +768,19 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
return false
}
const isMcpToolAlwaysAllowed = () => {
const lastMessage = messages.at(-1)
if (lastMessage?.type === "ask" && lastMessage.ask === "use_mcp_server" && lastMessage.text) {
const mcpServerUse = JSON.parse(lastMessage.text) as { type: string; serverName: string; toolName: string }
if (mcpServerUse.type === "use_mcp_tool") {
const server = mcpServers?.find((s: McpServer) => s.name === mcpServerUse.serverName)
const tool = server?.tools?.find((t: McpTool) => t.name === mcpServerUse.toolName)
return tool?.alwaysAllow || false
}
}
return false
}
const isAllowedCommand = () => {
const lastMessage = messages.at(-1)
if (lastMessage?.type === "ask" && lastMessage.text) {
@@ -788,11 +802,12 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
(alwaysAllowBrowser && clineAsk === "browser_action_launch") ||
(alwaysAllowReadOnly && clineAsk === "tool" && isReadOnlyToolAction()) ||
(alwaysAllowWrite && clineAsk === "tool" && isWriteToolAction()) ||
(alwaysAllowExecute && clineAsk === "command" && isAllowedCommand())
(alwaysAllowExecute && clineAsk === "command" && isAllowedCommand()) ||
(clineAsk === "use_mcp_server" && isMcpToolAlwaysAllowed())
) {
handlePrimaryButtonClick()
}
}, [clineAsk, enableButtons, handlePrimaryButtonClick, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, messages, allowedCommands])
}, [clineAsk, enableButtons, handlePrimaryButtonClick, alwaysAllowBrowser, alwaysAllowReadOnly, alwaysAllowWrite, alwaysAllowExecute, messages, allowedCommands, mcpServers])
return (
<div

View File

@@ -1,20 +1,46 @@
import { VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
import { McpTool } from "../../../../src/shared/mcp"
import { vscode } from "../../utils/vscode"
type McpToolRowProps = {
tool: McpTool
serverName?: string
}
const McpToolRow = ({ tool }: McpToolRowProps) => {
const McpToolRow = ({ tool, serverName }: McpToolRowProps) => {
const handleAlwaysAllowChange = () => {
if (!serverName) return;
vscode.postMessage({
type: "toggleToolAlwaysAllow",
serverName,
toolName: tool.name,
alwaysAllow: !tool.alwaysAllow
});
}
return (
<div
key={tool.name}
style={{
padding: "3px 0",
}}>
<div style={{ display: "flex" }}>
<div
style={{ display: "flex", alignItems: "center", justifyContent: "space-between" }}
onClick={(e) => e.stopPropagation()}>
<div style={{ display: "flex", alignItems: "center" }}>
<span className="codicon codicon-symbol-method" style={{ marginRight: "6px" }}></span>
<span style={{ fontWeight: 500 }}>{tool.name}</span>
</div>
{serverName && (
<VSCodeCheckbox
checked={tool.alwaysAllow}
onChange={handleAlwaysAllowChange}
data-tool={tool.name}>
Always allow
</VSCodeCheckbox>
)}
</div>
{tool.description && (
<div
style={{

View File

@@ -256,7 +256,11 @@ const ServerRow = ({ server }: { server: McpServer }) => {
<div
style={{ display: "flex", flexDirection: "column", gap: "8px", width: "100%" }}>
{server.tools.map((tool) => (
<McpToolRow key={tool.name} tool={tool} />
<McpToolRow
key={tool.name}
tool={tool}
serverName={server.name}
/>
))}
</div>
) : (

View File

@@ -0,0 +1,107 @@
import React from 'react'
import { render, fireEvent, screen } from '@testing-library/react'
import McpToolRow from '../McpToolRow'
import { vscode } from '../../../utils/vscode'
jest.mock('../../../utils/vscode', () => ({
vscode: {
postMessage: jest.fn()
}
}))
describe('McpToolRow', () => {
const mockTool = {
name: 'test-tool',
description: 'A test tool',
alwaysAllow: false
}
beforeEach(() => {
jest.clearAllMocks()
})
it('renders tool name and description', () => {
render(<McpToolRow tool={mockTool} />)
expect(screen.getByText('test-tool')).toBeInTheDocument()
expect(screen.getByText('A test tool')).toBeInTheDocument()
})
it('does not show always allow checkbox when serverName is not provided', () => {
render(<McpToolRow tool={mockTool} />)
expect(screen.queryByText('Always allow')).not.toBeInTheDocument()
})
it('shows always allow checkbox when serverName is provided', () => {
render(<McpToolRow tool={mockTool} serverName="test-server" />)
expect(screen.getByText('Always allow')).toBeInTheDocument()
})
it('sends message to toggle always allow when checkbox is clicked', () => {
render(<McpToolRow tool={mockTool} serverName="test-server" />)
const checkbox = screen.getByRole('checkbox')
fireEvent.click(checkbox)
expect(vscode.postMessage).toHaveBeenCalledWith({
type: 'toggleToolAlwaysAllow',
serverName: 'test-server',
toolName: 'test-tool',
alwaysAllow: true
})
})
it('reflects always allow state in checkbox', () => {
const alwaysAllowedTool = {
...mockTool,
alwaysAllow: true
}
render(<McpToolRow tool={alwaysAllowedTool} serverName="test-server" />)
const checkbox = screen.getByRole('checkbox')
expect(checkbox).toBeChecked()
})
it('prevents event propagation when clicking the checkbox', () => {
const mockStopPropagation = jest.fn()
render(<McpToolRow tool={mockTool} serverName="test-server" />)
const container = screen.getByTestId('tool-row-container')
fireEvent.click(container, {
stopPropagation: mockStopPropagation
})
expect(mockStopPropagation).toHaveBeenCalled()
})
it('displays input schema parameters when provided', () => {
const toolWithSchema = {
...mockTool,
inputSchema: {
type: 'object',
properties: {
param1: {
type: 'string',
description: 'First parameter'
},
param2: {
type: 'number',
description: 'Second parameter'
}
},
required: ['param1']
}
}
render(<McpToolRow tool={toolWithSchema} serverName="test-server" />)
expect(screen.getByText('Parameters')).toBeInTheDocument()
expect(screen.getByText('param1')).toBeInTheDocument()
expect(screen.getByText('param2')).toBeInTheDocument()
expect(screen.getByText('First parameter')).toBeInTheDocument()
expect(screen.getByText('Second parameter')).toBeInTheDocument()
})
})