Skip tool execution if user rejected a previous tool

This commit is contained in:
Saoud Rizwan
2024-09-12 17:23:53 -04:00
parent d06bb18505
commit c29fdaa520

View File

@@ -674,7 +674,7 @@ export class ClaudeDev {
this.terminalManager.disposeAll() this.terminalManager.disposeAll()
} }
async executeTool(toolName: ToolName, toolInput: any): Promise<ToolResponse> { async executeTool(toolName: ToolName, toolInput: any): Promise<[boolean, ToolResponse]> {
switch (toolName) { switch (toolName) {
case "write_to_file": case "write_to_file":
return this.writeToFile(toolInput.path, toolInput.content) return this.writeToFile(toolInput.path, toolInput.content)
@@ -693,7 +693,7 @@ export class ClaudeDev {
case "attempt_completion": case "attempt_completion":
return this.attemptCompletion(toolInput.result, toolInput.command) return this.attemptCompletion(toolInput.result, toolInput.command)
default: default:
return `Unknown tool: ${toolName}` return [false, `Unknown tool: ${toolName}`]
} }
} }
@@ -719,10 +719,11 @@ export class ClaudeDev {
return totalCost return totalCost
} }
async writeToFile(relPath?: string, newContent?: string): Promise<ToolResponse> { // return is [didUserRejectTool, ToolResponse]
async writeToFile(relPath?: string, newContent?: string): Promise<[boolean, ToolResponse]> {
if (relPath === undefined) { if (relPath === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("write_to_file", "path") return [false, await this.sayAndCreateMissingParamError("write_to_file", "path")]
} }
if (newContent === undefined) { if (newContent === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
@@ -731,9 +732,12 @@ export class ClaudeDev {
"error", "error",
`Claude tried to use write_to_file for '${relPath}' without value for required parameter 'content'. This is likely due to reaching the maximum output token limit. Retrying with suggestion to change response size...` `Claude tried to use write_to_file for '${relPath}' without value for required parameter 'content'. This is likely due to reaching the maximum output token limit. Retrying with suggestion to change response size...`
) )
return await this.formatToolError( return [
`Missing value for required parameter 'content'. This may occur if the file is too large, exceeding output limits. Consider splitting into smaller files or reducing content size. Please retry with all required parameters.` false,
) await this.formatToolError(
`Missing value for required parameter 'content'. This may occur if the file is too large, exceeding output limits. Consider splitting into smaller files or reducing content size. Please retry with all required parameters.`
),
]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
try { try {
@@ -961,9 +965,9 @@ export class ClaudeDev {
if (response === "messageResponse") { if (response === "messageResponse") {
await this.say("user_feedback", text, images) await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images) return [true, this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)]
} }
return await this.formatToolDenied() return [true, await this.formatToolDenied()]
} }
const editedContent = updatedDocument.getText() const editedContent = updatedDocument.getText()
@@ -1040,11 +1044,14 @@ export class ClaudeDev {
diff: this.createPrettyPatch(relPath, normalizedNewContent, normalizedEditedContent), diff: this.createPrettyPatch(relPath, normalizedNewContent, normalizedEditedContent),
} as ClaudeSayTool) } as ClaudeSayTool)
) )
return this.formatToolResult( return [
`The user made the following updates to your content:\n\n${userDiff}\n\nThe updated content, which includes both your original modifications and the user's additional edits, has been successfully saved to ${relPath}. Note this does not mean you need to re-write the file with the user's changes, they have already been applied to the file.` false,
) await this.formatToolResult(
`The user made the following updates to your content:\n\n${userDiff}\n\nThe updated content, which includes both your original modifications and the user's additional edits, has been successfully saved to ${relPath}. Note this does not mean you need to re-write the file with the user's changes, they have already been applied to the file.`
),
]
} else { } else {
return this.formatToolResult(`The content was successfully saved to ${relPath}.`) return [false, await this.formatToolResult(`The content was successfully saved to ${relPath}.`)]
} }
} catch (error) { } catch (error) {
const errorString = `Error writing file: ${JSON.stringify(serializeError(error))}` const errorString = `Error writing file: ${JSON.stringify(serializeError(error))}`
@@ -1052,7 +1059,7 @@ export class ClaudeDev {
"error", "error",
`Error writing file:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}` `Error writing file:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}`
) )
return await this.formatToolError(errorString) return [false, await this.formatToolError(errorString)]
} }
} }
@@ -1125,10 +1132,10 @@ export class ClaudeDev {
} }
} }
async readFile(relPath?: string): Promise<ToolResponse> { async readFile(relPath?: string): Promise<[boolean, ToolResponse]> {
if (relPath === undefined) { if (relPath === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("read_file", "path") return [false, await this.sayAndCreateMissingParamError("read_file", "path")]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
try { try {
@@ -1147,27 +1154,30 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") { if (response !== "yesButtonTapped") {
if (response === "messageResponse") { if (response === "messageResponse") {
await this.say("user_feedback", text, images) await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images) return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
} }
return await this.formatToolDenied() return [true, await this.formatToolDenied()]
} }
} }
return content return [false, content]
} catch (error) { } catch (error) {
const errorString = `Error reading file: ${JSON.stringify(serializeError(error))}` const errorString = `Error reading file: ${JSON.stringify(serializeError(error))}`
await this.say( await this.say(
"error", "error",
`Error reading file:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}` `Error reading file:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}`
) )
return await this.formatToolError(errorString) return [false, await this.formatToolError(errorString)]
} }
} }
async listFiles(relDirPath?: string, recursiveRaw?: string): Promise<ToolResponse> { async listFiles(relDirPath?: string, recursiveRaw?: string): Promise<[boolean, ToolResponse]> {
if (relDirPath === undefined) { if (relDirPath === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("list_files", "path") return [false, await this.sayAndCreateMissingParamError("list_files", "path")]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
try { try {
@@ -1188,13 +1198,16 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") { if (response !== "yesButtonTapped") {
if (response === "messageResponse") { if (response === "messageResponse") {
await this.say("user_feedback", text, images) await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images) return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
} }
return await this.formatToolDenied() return [true, await this.formatToolDenied()]
} }
} }
return this.formatToolResult(result) return [false, await this.formatToolResult(result)]
} catch (error) { } catch (error) {
const errorString = `Error listing files and directories: ${JSON.stringify(serializeError(error))}` const errorString = `Error listing files and directories: ${JSON.stringify(serializeError(error))}`
await this.say( await this.say(
@@ -1203,7 +1216,7 @@ export class ClaudeDev {
error.message ?? JSON.stringify(serializeError(error), null, 2) error.message ?? JSON.stringify(serializeError(error), null, 2)
}` }`
) )
return await this.formatToolError(errorString) return [false, await this.formatToolError(errorString)]
} }
} }
@@ -1266,10 +1279,10 @@ export class ClaudeDev {
} }
} }
async listCodeDefinitionNames(relDirPath?: string): Promise<ToolResponse> { async listCodeDefinitionNames(relDirPath?: string): Promise<[boolean, ToolResponse]> {
if (relDirPath === undefined) { if (relDirPath === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("list_code_definition_names", "path") return [false, await this.sayAndCreateMissingParamError("list_code_definition_names", "path")]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
try { try {
@@ -1288,13 +1301,16 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") { if (response !== "yesButtonTapped") {
if (response === "messageResponse") { if (response === "messageResponse") {
await this.say("user_feedback", text, images) await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images) return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
} }
return await this.formatToolDenied() return [true, await this.formatToolDenied()]
} }
} }
return this.formatToolResult(result) return [false, await this.formatToolResult(result)]
} catch (error) { } catch (error) {
const errorString = `Error parsing source code definitions: ${JSON.stringify(serializeError(error))}` const errorString = `Error parsing source code definitions: ${JSON.stringify(serializeError(error))}`
await this.say( await this.say(
@@ -1303,18 +1319,18 @@ export class ClaudeDev {
error.message ?? JSON.stringify(serializeError(error), null, 2) error.message ?? JSON.stringify(serializeError(error), null, 2)
}` }`
) )
return await this.formatToolError(errorString) return [false, await this.formatToolError(errorString)]
} }
} }
async searchFiles(relDirPath: string, regex: string, filePattern?: string): Promise<ToolResponse> { async searchFiles(relDirPath: string, regex: string, filePattern?: string): Promise<[boolean, ToolResponse]> {
if (relDirPath === undefined) { if (relDirPath === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("search_files", "path") return [false, await this.sayAndCreateMissingParamError("search_files", "path")]
} }
if (regex === undefined) { if (regex === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("search_files", "regex", relDirPath) return [false, await this.sayAndCreateMissingParamError("search_files", "regex", relDirPath)]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
try { try {
@@ -1336,36 +1352,42 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") { if (response !== "yesButtonTapped") {
if (response === "messageResponse") { if (response === "messageResponse") {
await this.say("user_feedback", text, images) await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images) return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
} }
return await this.formatToolDenied() return [true, await this.formatToolDenied()]
} }
} }
return this.formatToolResult(results) return [false, await this.formatToolResult(results)]
} catch (error) { } catch (error) {
const errorString = `Error searching files: ${JSON.stringify(serializeError(error))}` const errorString = `Error searching files: ${JSON.stringify(serializeError(error))}`
await this.say( await this.say(
"error", "error",
`Error searching files:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}` `Error searching files:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}`
) )
return await this.formatToolError(errorString) return [false, await this.formatToolError(errorString)]
} }
} }
async executeCommand(command?: string, returnEmptyStringOnSuccess: boolean = false): Promise<ToolResponse> { async executeCommand(
command?: string,
returnEmptyStringOnSuccess: boolean = false
): Promise<[boolean, ToolResponse]> {
if (command === undefined) { if (command === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("execute_command", "command") return [false, await this.sayAndCreateMissingParamError("execute_command", "command")]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
const { response, text, images } = await this.ask("command", command) const { response, text, images } = await this.ask("command", command)
if (response !== "yesButtonTapped") { if (response !== "yesButtonTapped") {
if (response === "messageResponse") { if (response === "messageResponse") {
await this.say("user_feedback", text, images) await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images) return [true, this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)]
} }
return await this.formatToolDenied() return [true, await this.formatToolDenied()]
} }
try { try {
@@ -1422,75 +1444,85 @@ export class ClaudeDev {
if (userFeedback) { if (userFeedback) {
await this.say("user_feedback", userFeedback.text, userFeedback.images) await this.say("user_feedback", userFeedback.text, userFeedback.images)
return this.formatToolResponseWithImages( return [
`Command is still running in the user's terminal.${ true,
result.length > 0 ? `\nHere's the output so far:\n${result}` : "" this.formatToolResponseWithImages(
}\n\nThe user provided the following feedback:\n<feedback>\n${userFeedback.text}\n</feedback>`, `Command is still running in the user's terminal.${
userFeedback.images result.length > 0 ? `\nHere's the output so far:\n${result}` : ""
) }\n\nThe user provided the following feedback:\n<feedback>\n${userFeedback.text}\n</feedback>`,
userFeedback.images
),
]
} }
// for attemptCompletion, we don't want to return the command output // for attemptCompletion, we don't want to return the command output
if (returnEmptyStringOnSuccess) { if (returnEmptyStringOnSuccess) {
return "" return [false, ""]
} }
if (completed) { if (completed) {
return await this.formatToolResult( return [
`Command executed.${result.length > 0 ? `\nOutput:\n${result}` : ""}` false,
) await this.formatToolResult(`Command executed.${result.length > 0 ? `\nOutput:\n${result}` : ""}`),
]
} else { } else {
return await this.formatToolResult( return [
`Command is still running in the user's terminal.${ false,
result.length > 0 ? `\nHere's the output so far:\n${result}` : "" await this.formatToolResult(
}\n\nYou will be updated on the terminal status and new output in the future.` `Command is still running in the user's terminal.${
) result.length > 0 ? `\nHere's the output so far:\n${result}` : ""
}\n\nYou will be updated on the terminal status and new output in the future.`
),
]
} }
} catch (error) { } catch (error) {
let errorMessage = error.message || JSON.stringify(serializeError(error), null, 2) let errorMessage = error.message || JSON.stringify(serializeError(error), null, 2)
const errorString = `Error executing command:\n${errorMessage}` const errorString = `Error executing command:\n${errorMessage}`
await this.say("error", `Error executing command:\n${errorMessage}`) await this.say("error", `Error executing command:\n${errorMessage}`)
return await this.formatToolError(errorString) return [false, await this.formatToolError(errorString)]
} }
} }
async askFollowupQuestion(question?: string): Promise<ToolResponse> { async askFollowupQuestion(question?: string): Promise<[boolean, ToolResponse]> {
if (question === undefined) { if (question === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("ask_followup_question", "question") return [false, await this.sayAndCreateMissingParamError("ask_followup_question", "question")]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
const { text, images } = await this.ask("followup", question) const { text, images } = await this.ask("followup", question)
await this.say("user_feedback", text ?? "", images) await this.say("user_feedback", text ?? "", images)
return this.formatToolResponseWithImages(`<answer>\n${text}\n</answer>`, images) return [false, this.formatToolResponseWithImages(`<answer>\n${text}\n</answer>`, images)]
} }
async attemptCompletion(result?: string, command?: string): Promise<ToolResponse> { async attemptCompletion(result?: string, command?: string): Promise<[boolean, ToolResponse]> {
// result is required, command is optional // result is required, command is optional
if (result === undefined) { if (result === undefined) {
this.consecutiveMistakeCount++ this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("attempt_completion", "result") return [false, await this.sayAndCreateMissingParamError("attempt_completion", "result")]
} }
this.consecutiveMistakeCount = 0 this.consecutiveMistakeCount = 0
let resultToSend = result let resultToSend = result
if (command) { if (command) {
await this.say("completion_result", resultToSend) await this.say("completion_result", resultToSend)
// TODO: currently we don't handle if this command fails, it could be useful to let claude know and retry // TODO: currently we don't handle if this command fails, it could be useful to let claude know and retry
const commandResult = await this.executeCommand(command, true) const [didUserReject, commandResult] = await this.executeCommand(command, true)
// if we received non-empty string, the command was rejected or failed // if we received non-empty string, the command was rejected or failed
if (commandResult) { if (commandResult) {
return commandResult return [didUserReject, commandResult]
} }
resultToSend = "" resultToSend = ""
} }
const { response, text, images } = await this.ask("completion_result", resultToSend) // this prompts webview to show 'new task' button, and enable text input (which would be the 'text' here) const { response, text, images } = await this.ask("completion_result", resultToSend) // this prompts webview to show 'new task' button, and enable text input (which would be the 'text' here)
if (response === "yesButtonTapped") { if (response === "yesButtonTapped") {
return "" // signals to recursive loop to stop (for now this never happens since yesButtonTapped will trigger a new task) return [false, ""] // signals to recursive loop to stop (for now this never happens since yesButtonTapped will trigger a new task)
} }
await this.say("user_feedback", text ?? "", images) await this.say("user_feedback", text ?? "", images)
return this.formatToolResponseWithImages( return [
`The user has provided feedback on the results. Consider their input to continue the task, and then attempt completion again.\n<feedback>\n${text}\n</feedback>`, true,
images this.formatToolResponseWithImages(
) `The user has provided feedback on the results. Consider their input to continue the task, and then attempt completion again.\n<feedback>\n${text}\n</feedback>`,
images
),
]
} }
async attemptApiRequest(): Promise<Anthropic.Messages.Message> { async attemptApiRequest(): Promise<Anthropic.Messages.Message> {
@@ -1668,21 +1700,31 @@ ${this.customInstructions.trim()}
let toolResults: Anthropic.ToolResultBlockParam[] = [] let toolResults: Anthropic.ToolResultBlockParam[] = []
let attemptCompletionBlock: Anthropic.Messages.ToolUseBlock | undefined let attemptCompletionBlock: Anthropic.Messages.ToolUseBlock | undefined
let userRejectedATool = false
for (const contentBlock of response.content) { for (const contentBlock of response.content) {
if (contentBlock.type === "tool_use") { if (contentBlock.type === "tool_use") {
const toolName = contentBlock.name as ToolName const toolName = contentBlock.name as ToolName
const toolInput = contentBlock.input const toolInput = contentBlock.input
const toolUseId = contentBlock.id const toolUseId = contentBlock.id
if (userRejectedATool) {
toolResults.push({
type: "tool_result",
tool_use_id: toolUseId,
content: "Skipping tool execution due to previous tool user rejection.",
})
continue
}
if (toolName === "attempt_completion") { if (toolName === "attempt_completion") {
attemptCompletionBlock = contentBlock attemptCompletionBlock = contentBlock
} else { } else {
// NOTE: while anthropic sdk accepts string or array of string/image, openai sdk (openrouter) only accepts a string const [didUserReject, result] = await this.executeTool(toolName, toolInput)
const result = await this.executeTool(toolName, toolInput)
// this.say(
// "tool",
// `\nTool Used: ${toolName}\nTool Input: ${JSON.stringify(toolInput)}\nTool Result: ${result}`
// )
toolResults.push({ type: "tool_result", tool_use_id: toolUseId, content: result }) toolResults.push({ type: "tool_result", tool_use_id: toolUseId, content: result })
if (didUserReject) {
userRejectedATool = true
}
} }
} }
} }
@@ -1692,7 +1734,7 @@ ${this.customInstructions.trim()}
// attempt_completion is always done last, since there might have been other tools that needed to be called first before the job is finished // attempt_completion is always done last, since there might have been other tools that needed to be called first before the job is finished
// it's important to note that claude will order the tools logically in most cases, so we don't have to think about which tools make sense calling before others // it's important to note that claude will order the tools logically in most cases, so we don't have to think about which tools make sense calling before others
if (attemptCompletionBlock) { if (attemptCompletionBlock) {
let result = await this.executeTool( let [_, result] = await this.executeTool(
attemptCompletionBlock.name as ToolName, attemptCompletionBlock.name as ToolName,
attemptCompletionBlock.input attemptCompletionBlock.input
) )