Skip tool execution if user rejected a previous tool

This commit is contained in:
Saoud Rizwan
2024-09-12 17:23:53 -04:00
parent d06bb18505
commit c29fdaa520

View File

@@ -674,7 +674,7 @@ export class ClaudeDev {
this.terminalManager.disposeAll()
}
async executeTool(toolName: ToolName, toolInput: any): Promise<ToolResponse> {
async executeTool(toolName: ToolName, toolInput: any): Promise<[boolean, ToolResponse]> {
switch (toolName) {
case "write_to_file":
return this.writeToFile(toolInput.path, toolInput.content)
@@ -693,7 +693,7 @@ export class ClaudeDev {
case "attempt_completion":
return this.attemptCompletion(toolInput.result, toolInput.command)
default:
return `Unknown tool: ${toolName}`
return [false, `Unknown tool: ${toolName}`]
}
}
@@ -719,10 +719,11 @@ export class ClaudeDev {
return totalCost
}
async writeToFile(relPath?: string, newContent?: string): Promise<ToolResponse> {
// return is [didUserRejectTool, ToolResponse]
async writeToFile(relPath?: string, newContent?: string): Promise<[boolean, ToolResponse]> {
if (relPath === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("write_to_file", "path")
return [false, await this.sayAndCreateMissingParamError("write_to_file", "path")]
}
if (newContent === undefined) {
this.consecutiveMistakeCount++
@@ -731,9 +732,12 @@ export class ClaudeDev {
"error",
`Claude tried to use write_to_file for '${relPath}' without value for required parameter 'content'. This is likely due to reaching the maximum output token limit. Retrying with suggestion to change response size...`
)
return await this.formatToolError(
`Missing value for required parameter 'content'. This may occur if the file is too large, exceeding output limits. Consider splitting into smaller files or reducing content size. Please retry with all required parameters.`
)
return [
false,
await this.formatToolError(
`Missing value for required parameter 'content'. This may occur if the file is too large, exceeding output limits. Consider splitting into smaller files or reducing content size. Please retry with all required parameters.`
),
]
}
this.consecutiveMistakeCount = 0
try {
@@ -961,9 +965,9 @@ export class ClaudeDev {
if (response === "messageResponse") {
await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)
return [true, this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)]
}
return await this.formatToolDenied()
return [true, await this.formatToolDenied()]
}
const editedContent = updatedDocument.getText()
@@ -1040,11 +1044,14 @@ export class ClaudeDev {
diff: this.createPrettyPatch(relPath, normalizedNewContent, normalizedEditedContent),
} as ClaudeSayTool)
)
return this.formatToolResult(
`The user made the following updates to your content:\n\n${userDiff}\n\nThe updated content, which includes both your original modifications and the user's additional edits, has been successfully saved to ${relPath}. Note this does not mean you need to re-write the file with the user's changes, they have already been applied to the file.`
)
return [
false,
await this.formatToolResult(
`The user made the following updates to your content:\n\n${userDiff}\n\nThe updated content, which includes both your original modifications and the user's additional edits, has been successfully saved to ${relPath}. Note this does not mean you need to re-write the file with the user's changes, they have already been applied to the file.`
),
]
} else {
return this.formatToolResult(`The content was successfully saved to ${relPath}.`)
return [false, await this.formatToolResult(`The content was successfully saved to ${relPath}.`)]
}
} catch (error) {
const errorString = `Error writing file: ${JSON.stringify(serializeError(error))}`
@@ -1052,7 +1059,7 @@ export class ClaudeDev {
"error",
`Error writing file:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}`
)
return await this.formatToolError(errorString)
return [false, await this.formatToolError(errorString)]
}
}
@@ -1125,10 +1132,10 @@ export class ClaudeDev {
}
}
async readFile(relPath?: string): Promise<ToolResponse> {
async readFile(relPath?: string): Promise<[boolean, ToolResponse]> {
if (relPath === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("read_file", "path")
return [false, await this.sayAndCreateMissingParamError("read_file", "path")]
}
this.consecutiveMistakeCount = 0
try {
@@ -1147,27 +1154,30 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") {
if (response === "messageResponse") {
await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)
return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
}
return await this.formatToolDenied()
return [true, await this.formatToolDenied()]
}
}
return content
return [false, content]
} catch (error) {
const errorString = `Error reading file: ${JSON.stringify(serializeError(error))}`
await this.say(
"error",
`Error reading file:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}`
)
return await this.formatToolError(errorString)
return [false, await this.formatToolError(errorString)]
}
}
async listFiles(relDirPath?: string, recursiveRaw?: string): Promise<ToolResponse> {
async listFiles(relDirPath?: string, recursiveRaw?: string): Promise<[boolean, ToolResponse]> {
if (relDirPath === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("list_files", "path")
return [false, await this.sayAndCreateMissingParamError("list_files", "path")]
}
this.consecutiveMistakeCount = 0
try {
@@ -1188,13 +1198,16 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") {
if (response === "messageResponse") {
await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)
return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
}
return await this.formatToolDenied()
return [true, await this.formatToolDenied()]
}
}
return this.formatToolResult(result)
return [false, await this.formatToolResult(result)]
} catch (error) {
const errorString = `Error listing files and directories: ${JSON.stringify(serializeError(error))}`
await this.say(
@@ -1203,7 +1216,7 @@ export class ClaudeDev {
error.message ?? JSON.stringify(serializeError(error), null, 2)
}`
)
return await this.formatToolError(errorString)
return [false, await this.formatToolError(errorString)]
}
}
@@ -1266,10 +1279,10 @@ export class ClaudeDev {
}
}
async listCodeDefinitionNames(relDirPath?: string): Promise<ToolResponse> {
async listCodeDefinitionNames(relDirPath?: string): Promise<[boolean, ToolResponse]> {
if (relDirPath === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("list_code_definition_names", "path")
return [false, await this.sayAndCreateMissingParamError("list_code_definition_names", "path")]
}
this.consecutiveMistakeCount = 0
try {
@@ -1288,13 +1301,16 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") {
if (response === "messageResponse") {
await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)
return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
}
return await this.formatToolDenied()
return [true, await this.formatToolDenied()]
}
}
return this.formatToolResult(result)
return [false, await this.formatToolResult(result)]
} catch (error) {
const errorString = `Error parsing source code definitions: ${JSON.stringify(serializeError(error))}`
await this.say(
@@ -1303,18 +1319,18 @@ export class ClaudeDev {
error.message ?? JSON.stringify(serializeError(error), null, 2)
}`
)
return await this.formatToolError(errorString)
return [false, await this.formatToolError(errorString)]
}
}
async searchFiles(relDirPath: string, regex: string, filePattern?: string): Promise<ToolResponse> {
async searchFiles(relDirPath: string, regex: string, filePattern?: string): Promise<[boolean, ToolResponse]> {
if (relDirPath === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("search_files", "path")
return [false, await this.sayAndCreateMissingParamError("search_files", "path")]
}
if (regex === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("search_files", "regex", relDirPath)
return [false, await this.sayAndCreateMissingParamError("search_files", "regex", relDirPath)]
}
this.consecutiveMistakeCount = 0
try {
@@ -1336,36 +1352,42 @@ export class ClaudeDev {
if (response !== "yesButtonTapped") {
if (response === "messageResponse") {
await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)
return [
true,
this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images),
]
}
return await this.formatToolDenied()
return [true, await this.formatToolDenied()]
}
}
return this.formatToolResult(results)
return [false, await this.formatToolResult(results)]
} catch (error) {
const errorString = `Error searching files: ${JSON.stringify(serializeError(error))}`
await this.say(
"error",
`Error searching files:\n${error.message ?? JSON.stringify(serializeError(error), null, 2)}`
)
return await this.formatToolError(errorString)
return [false, await this.formatToolError(errorString)]
}
}
async executeCommand(command?: string, returnEmptyStringOnSuccess: boolean = false): Promise<ToolResponse> {
async executeCommand(
command?: string,
returnEmptyStringOnSuccess: boolean = false
): Promise<[boolean, ToolResponse]> {
if (command === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("execute_command", "command")
return [false, await this.sayAndCreateMissingParamError("execute_command", "command")]
}
this.consecutiveMistakeCount = 0
const { response, text, images } = await this.ask("command", command)
if (response !== "yesButtonTapped") {
if (response === "messageResponse") {
await this.say("user_feedback", text, images)
return this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)
return [true, this.formatToolResponseWithImages(await this.formatToolDeniedFeedback(text), images)]
}
return await this.formatToolDenied()
return [true, await this.formatToolDenied()]
}
try {
@@ -1422,75 +1444,85 @@ export class ClaudeDev {
if (userFeedback) {
await this.say("user_feedback", userFeedback.text, userFeedback.images)
return this.formatToolResponseWithImages(
`Command is still running in the user's terminal.${
result.length > 0 ? `\nHere's the output so far:\n${result}` : ""
}\n\nThe user provided the following feedback:\n<feedback>\n${userFeedback.text}\n</feedback>`,
userFeedback.images
)
return [
true,
this.formatToolResponseWithImages(
`Command is still running in the user's terminal.${
result.length > 0 ? `\nHere's the output so far:\n${result}` : ""
}\n\nThe user provided the following feedback:\n<feedback>\n${userFeedback.text}\n</feedback>`,
userFeedback.images
),
]
}
// for attemptCompletion, we don't want to return the command output
if (returnEmptyStringOnSuccess) {
return ""
return [false, ""]
}
if (completed) {
return await this.formatToolResult(
`Command executed.${result.length > 0 ? `\nOutput:\n${result}` : ""}`
)
return [
false,
await this.formatToolResult(`Command executed.${result.length > 0 ? `\nOutput:\n${result}` : ""}`),
]
} else {
return await this.formatToolResult(
`Command is still running in the user's terminal.${
result.length > 0 ? `\nHere's the output so far:\n${result}` : ""
}\n\nYou will be updated on the terminal status and new output in the future.`
)
return [
false,
await this.formatToolResult(
`Command is still running in the user's terminal.${
result.length > 0 ? `\nHere's the output so far:\n${result}` : ""
}\n\nYou will be updated on the terminal status and new output in the future.`
),
]
}
} catch (error) {
let errorMessage = error.message || JSON.stringify(serializeError(error), null, 2)
const errorString = `Error executing command:\n${errorMessage}`
await this.say("error", `Error executing command:\n${errorMessage}`)
return await this.formatToolError(errorString)
return [false, await this.formatToolError(errorString)]
}
}
async askFollowupQuestion(question?: string): Promise<ToolResponse> {
async askFollowupQuestion(question?: string): Promise<[boolean, ToolResponse]> {
if (question === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("ask_followup_question", "question")
return [false, await this.sayAndCreateMissingParamError("ask_followup_question", "question")]
}
this.consecutiveMistakeCount = 0
const { text, images } = await this.ask("followup", question)
await this.say("user_feedback", text ?? "", images)
return this.formatToolResponseWithImages(`<answer>\n${text}\n</answer>`, images)
return [false, this.formatToolResponseWithImages(`<answer>\n${text}\n</answer>`, images)]
}
async attemptCompletion(result?: string, command?: string): Promise<ToolResponse> {
async attemptCompletion(result?: string, command?: string): Promise<[boolean, ToolResponse]> {
// result is required, command is optional
if (result === undefined) {
this.consecutiveMistakeCount++
return await this.sayAndCreateMissingParamError("attempt_completion", "result")
return [false, await this.sayAndCreateMissingParamError("attempt_completion", "result")]
}
this.consecutiveMistakeCount = 0
let resultToSend = result
if (command) {
await this.say("completion_result", resultToSend)
// TODO: currently we don't handle if this command fails, it could be useful to let claude know and retry
const commandResult = await this.executeCommand(command, true)
const [didUserReject, commandResult] = await this.executeCommand(command, true)
// if we received non-empty string, the command was rejected or failed
if (commandResult) {
return commandResult
return [didUserReject, commandResult]
}
resultToSend = ""
}
const { response, text, images } = await this.ask("completion_result", resultToSend) // this prompts webview to show 'new task' button, and enable text input (which would be the 'text' here)
if (response === "yesButtonTapped") {
return "" // signals to recursive loop to stop (for now this never happens since yesButtonTapped will trigger a new task)
return [false, ""] // signals to recursive loop to stop (for now this never happens since yesButtonTapped will trigger a new task)
}
await this.say("user_feedback", text ?? "", images)
return this.formatToolResponseWithImages(
`The user has provided feedback on the results. Consider their input to continue the task, and then attempt completion again.\n<feedback>\n${text}\n</feedback>`,
images
)
return [
true,
this.formatToolResponseWithImages(
`The user has provided feedback on the results. Consider their input to continue the task, and then attempt completion again.\n<feedback>\n${text}\n</feedback>`,
images
),
]
}
async attemptApiRequest(): Promise<Anthropic.Messages.Message> {
@@ -1668,21 +1700,31 @@ ${this.customInstructions.trim()}
let toolResults: Anthropic.ToolResultBlockParam[] = []
let attemptCompletionBlock: Anthropic.Messages.ToolUseBlock | undefined
let userRejectedATool = false
for (const contentBlock of response.content) {
if (contentBlock.type === "tool_use") {
const toolName = contentBlock.name as ToolName
const toolInput = contentBlock.input
const toolUseId = contentBlock.id
if (userRejectedATool) {
toolResults.push({
type: "tool_result",
tool_use_id: toolUseId,
content: "Skipping tool execution due to previous tool user rejection.",
})
continue
}
if (toolName === "attempt_completion") {
attemptCompletionBlock = contentBlock
} else {
// NOTE: while anthropic sdk accepts string or array of string/image, openai sdk (openrouter) only accepts a string
const result = await this.executeTool(toolName, toolInput)
// this.say(
// "tool",
// `\nTool Used: ${toolName}\nTool Input: ${JSON.stringify(toolInput)}\nTool Result: ${result}`
// )
const [didUserReject, result] = await this.executeTool(toolName, toolInput)
toolResults.push({ type: "tool_result", tool_use_id: toolUseId, content: result })
if (didUserReject) {
userRejectedATool = true
}
}
}
}
@@ -1692,7 +1734,7 @@ ${this.customInstructions.trim()}
// attempt_completion is always done last, since there might have been other tools that needed to be called first before the job is finished
// it's important to note that claude will order the tools logically in most cases, so we don't have to think about which tools make sense calling before others
if (attemptCompletionBlock) {
let result = await this.executeTool(
let [_, result] = await this.executeTool(
attemptCompletionBlock.name as ToolName,
attemptCompletionBlock.input
)