From 9b1b9c10a1ea36b8e4f0b27539f15cfe2db2f19b Mon Sep 17 00:00:00 2001 From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com> Date: Mon, 30 Sep 2024 18:47:33 -0400 Subject: [PATCH] Add Cancel button --- src/core/ClaudeDev.ts | 135 ++++++++++++-------- src/core/webview/ClaudeDevProvider.ts | 13 ++ src/shared/ExtensionMessage.ts | 10 ++ src/shared/WebviewMessage.ts | 1 + webview-ui/src/components/chat/ChatRow.tsx | 30 +++-- webview-ui/src/components/chat/ChatView.tsx | 46 ++++++- 6 files changed, 167 insertions(+), 68 deletions(-) diff --git a/src/core/ClaudeDev.ts b/src/core/ClaudeDev.ts index a875df6..11073f0 100644 --- a/src/core/ClaudeDev.ts +++ b/src/core/ClaudeDev.ts @@ -21,7 +21,7 @@ import { ApiConfiguration } from "../shared/api" import { findLastIndex } from "../shared/array" import { combineApiRequests } from "../shared/combineApiRequests" import { combineCommandSequences } from "../shared/combineCommandSequences" -import { ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "../shared/ExtensionMessage" +import { ClaudeApiReqInfo, ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "../shared/ExtensionMessage" import { getApiMetrics } from "../shared/getApiMetrics" import { HistoryItem } from "../shared/HistoryItem" import { ToolName } from "../shared/Tool" @@ -69,6 +69,7 @@ export class ClaudeDev { private consecutiveMistakeCount: number = 0 private providerRef: WeakRef private abort: boolean = false + didFinishAborting = false private diffViewProvider: DiffViewProvider // streaming @@ -381,19 +382,6 @@ export class ClaudeDev { private async resumeTaskFromHistory() { const modifiedClaudeMessages = await this.getSavedClaudeMessages() - // Need to modify claude messages for good ux, i.e. if the last message is an api_request_started, then remove it otherwise the user will think the request is still loading - const lastApiReqStartedIndex = modifiedClaudeMessages.reduce( - (lastIndex, m, index) => (m.type === "say" && m.say === "api_req_started" ? index : lastIndex), - -1 - ) - const lastApiReqFinishedIndex = modifiedClaudeMessages.reduce( - (lastIndex, m, index) => (m.type === "say" && m.say === "api_req_finished" ? index : lastIndex), - -1 - ) - if (lastApiReqStartedIndex > lastApiReqFinishedIndex && lastApiReqStartedIndex !== -1) { - modifiedClaudeMessages.splice(lastApiReqStartedIndex, 1) - } - // Remove any resume messages that may have been added before const lastRelevantMessageIndex = findLastIndex( modifiedClaudeMessages, @@ -403,6 +391,23 @@ export class ClaudeDev { modifiedClaudeMessages.splice(lastRelevantMessageIndex + 1) } + // if the last message is an api_req_started it means there was no partial content streamed, so we remove it + if (modifiedClaudeMessages.at(-1)?.say === "api_req_started") { + modifiedClaudeMessages.pop() + } + // since we don't use api_req_finished anymore, we need to check if the last api_req_started has a cost value, if it doesn't and it's not cancelled, then we remove it since it indicates an api request without any partial content streamed + // const lastApiReqStartedIndex = findLastIndex( + // modifiedClaudeMessages, + // (m) => m.type === "say" && m.say === "api_req_started" + // ) + // if (lastApiReqStartedIndex !== -1) { + // const lastApiReqStarted = modifiedClaudeMessages[lastApiReqStartedIndex] + // const { cost, cancelled }: ClaudeApiReqInfo = JSON.parse(lastApiReqStarted.text || "{}") + // if (cost === undefined || cancelled) { + // modifiedClaudeMessages.splice(lastApiReqStartedIndex, 1) + // } + // } + await this.overwriteClaudeMessages(modifiedClaudeMessages) this.claudeMessages = await this.getSavedClaudeMessages() @@ -698,13 +703,9 @@ export class ClaudeDev { if (previousApiReqIndex >= 0) { const previousRequest = this.claudeMessages[previousApiReqIndex] if (previousRequest && previousRequest.text) { - const { - tokensIn, - tokensOut, - cacheWrites, - cacheReads, - }: { tokensIn?: number; tokensOut?: number; cacheWrites?: number; cacheReads?: number } = - JSON.parse(previousRequest.text) + const { tokensIn, tokensOut, cacheWrites, cacheReads }: ClaudeApiReqInfo = JSON.parse( + previousRequest.text + ) const totalTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0) const contextWindow = this.api.getModel().info.contextWindow const maxAllowedSize = Math.max(contextWindow - 40_000, contextWindow * 0.8) @@ -1584,7 +1585,7 @@ export class ClaudeDev { request: userContent .map((block) => formatContentBlockToMarkdown(block, this.apiConversationHistory)) .join("\n\n"), - }) + } satisfies ClaudeApiReqInfo) await this.saveClaudeMessages() await this.providerRef.deref()?.postStateToWebview() @@ -1596,6 +1597,29 @@ export class ClaudeDev { let outputTokens = 0 let totalCost: number | undefined + // update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed) + // fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history + // (it's worth removing a few months from now) + const updateApiReqMsg = (cancelled?: boolean) => { + this.claudeMessages[lastApiReqIndex].text = JSON.stringify({ + ...JSON.parse(this.claudeMessages[lastApiReqIndex].text || "{}"), + tokensIn: inputTokens, + tokensOut: outputTokens, + cacheWrites: cacheWriteTokens, + cacheReads: cacheReadTokens, + cost: + totalCost ?? + calculateApiCost( + this.api.getModel().info, + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens + ), + cancelled, + } satisfies ClaudeApiReqInfo) + } + // reset streaming state this.currentStreamingContentIndex = 0 this.assistantMessageContent = [] @@ -1624,6 +1648,42 @@ export class ClaudeDev { this.presentAssistantMessage() break } + + if (this.abort) { + console.log("aborting stream...") + if (this.diffViewProvider.isEditing) { + await this.diffViewProvider.revertChanges() // closes diff view + } + + // if last message is a partial we need to save it + const lastMessage = this.claudeMessages.at(-1) + if (lastMessage && lastMessage.partial) { + lastMessage.ts = Date.now() + lastMessage.partial = false + // instead of streaming partialMessage events, we do a save and post like normal to persist to disk + console.log("saving messages...", lastMessage) + // await this.saveClaudeMessages() + } + + // + await this.addToApiConversationHistory({ + role: "assistant", + content: [{ type: "text", text: assistantMessage + "\n\n[Response interrupted by user]" }], + }) + + // update api_req_started to have cancelled and cost, so that we can display the cost of the partial stream + updateApiReqMsg(true) + await this.saveClaudeMessages() + + // signals to provider that it can retrieve the saved messages from disk, as abortTask can not be awaited on in nature + this.didFinishAborting = true + break // aborts the stream + } + } + + // need to call here in case the stream was aborted + if (this.abort) { + throw new Error("ClaudeDev instance aborted") } this.didCompleteReadingStream = true @@ -1637,36 +1697,7 @@ export class ClaudeDev { this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request } - // let inputTokens = response.usage.input_tokens - // let outputTokens = response.usage.output_tokens - // let cacheCreationInputTokens = - // (response as Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaMessage).usage - // .cache_creation_input_tokens || undefined - // let cacheReadInputTokens = - // (response as Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaMessage).usage - // .cache_read_input_tokens || undefined - // @ts-ignore-next-line - // let totalCost = response.usage.total_cost - - // update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed) - // fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history - // (it's worth removing a few months from now) - this.claudeMessages[lastApiReqIndex].text = JSON.stringify({ - ...JSON.parse(this.claudeMessages[lastApiReqIndex].text), - tokensIn: inputTokens, - tokensOut: outputTokens, - cacheWrites: cacheWriteTokens, - cacheReads: cacheReadTokens, - cost: - totalCost ?? - calculateApiCost( - this.api.getModel().info, - inputTokens, - outputTokens, - cacheWriteTokens, - cacheReadTokens - ), - }) + updateApiReqMsg() await this.saveClaudeMessages() await this.providerRef.deref()?.postStateToWebview() diff --git a/src/core/webview/ClaudeDevProvider.ts b/src/core/webview/ClaudeDevProvider.ts index c959e4f..cfa2092 100644 --- a/src/core/webview/ClaudeDevProvider.ts +++ b/src/core/webview/ClaudeDevProvider.ts @@ -19,6 +19,7 @@ import WorkspaceTracker from "../../integrations/workspace/WorkspaceTracker" import { openMention } from "../mentions" import { fileExistsAtPath } from "../../utils/fs" import { buildApiHandler } from "../../api" +import pWaitFor from "p-wait-for" /* https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts @@ -441,6 +442,18 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider { break case "openMention": openMention(message.text) + break + case "cancelTask": + if (this.claudeDev) { + const { historyItem } = await this.getTaskWithId(this.claudeDev.taskId) + this.claudeDev.abortTask() + await pWaitFor(() => this.claudeDev === undefined || this.claudeDev.didFinishAborting, { + timeout: 3_000, + }) + await this.initClaudeDevWithHistoryItem(historyItem) // clears task again, so we need to abortTask manually above + await this.postStateToWebview() + } + break // Add more switch case statements here as more webview message commands // are created within the webview context (i.e. inside media/main.js) diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 072740c..38f917c 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -87,3 +87,13 @@ export interface ClaudeSayTool { regex?: string filePattern?: string } + +export interface ClaudeApiReqInfo { + request?: string + tokensIn?: number + tokensOut?: number + cacheWrites?: number + cacheReads?: number + cost?: number + cancelled?: boolean +} diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index ce678aa..51c663f 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -20,6 +20,7 @@ export interface WebviewMessage { | "openImage" | "openFile" | "openMention" + | "cancelTask" text?: string askResponse?: ClaudeAskResponse apiConfiguration?: ApiConfiguration diff --git a/webview-ui/src/components/chat/ChatRow.tsx b/webview-ui/src/components/chat/ChatRow.tsx index d6ce800..0231713 100644 --- a/webview-ui/src/components/chat/ChatRow.tsx +++ b/webview-ui/src/components/chat/ChatRow.tsx @@ -2,7 +2,7 @@ import { VSCodeBadge, VSCodeProgressRing } from "@vscode/webview-ui-toolkit/reac import deepEqual from "fast-deep-equal" import React, { memo, useMemo } from "react" import ReactMarkdown from "react-markdown" -import { ClaudeMessage, ClaudeSayTool } from "../../../../src/shared/ExtensionMessage" +import { ClaudeApiReqInfo, ClaudeMessage, ClaudeSayTool } from "../../../../src/shared/ExtensionMessage" import { COMMAND_OUTPUT_STRING } from "../../../../src/shared/combineCommandSequences" import { vscode } from "../../utils/vscode" import CodeAccordian, { removeLeadingNonAlphanumeric } from "../common/CodeAccordian" @@ -37,11 +37,12 @@ const ChatRow = memo( export default ChatRow const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessage, isLast }: ChatRowProps) => { - const cost = useMemo(() => { + const [cost, apiReqCancelled] = useMemo(() => { if (message.text != null && message.say === "api_req_started") { - return JSON.parse(message.text).cost + const info: ClaudeApiReqInfo = JSON.parse(message.text) + return [info.cost, info.cancelled] } - return undefined + return [undefined, undefined] }, [message.text, message.say]) const apiRequestFailedMessage = isLast && lastModifiedMessage?.ask === "api_req_failed" // if request is retried then the latest message is a api_req_retried @@ -54,6 +55,7 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa const normalColor = "var(--vscode-foreground)" const errorColor = "var(--vscode-errorForeground)" const successColor = "var(--vscode-charts-green)" + const cancelledColor = "var(--vscode-descriptionForeground)" const [icon, title] = useMemo(() => { switch (type) { @@ -94,9 +96,15 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa case "api_req_started": return [ cost != null ? ( - + apiReqCancelled ? ( + + ) : ( + + ) ) : apiRequestFailedMessage ? ( ), cost != null ? ( - API Request + apiReqCancelled ? ( + API Request Cancelled + ) : ( + API Request + ) ) : apiRequestFailedMessage ? ( API Request Failed ) : ( @@ -122,7 +134,7 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa default: return [null, null] } - }, [type, cost, apiRequestFailedMessage, isCommandExecuting]) + }, [type, cost, apiRequestFailedMessage, isCommandExecuting, apiReqCancelled]) const headerStyle: React.CSSProperties = { display: "flex", diff --git a/webview-ui/src/components/chat/ChatView.tsx b/webview-ui/src/components/chat/ChatView.tsx index dab9e02..133f242 100644 --- a/webview-ui/src/components/chat/ChatView.tsx +++ b/webview-ui/src/components/chat/ChatView.tsx @@ -14,6 +14,7 @@ import ChatRow from "./ChatRow" import ChatTextArea from "./ChatTextArea" import HistoryPreview from "../history/HistoryPreview" import TaskHeader from "./TaskHeader" +import { findLast } from "../../../../src/shared/array" interface ChatViewProps { isHidden: boolean @@ -182,6 +183,24 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie } }, [messages.length]) + const isStreaming = useMemo(() => { + const isLastMessagePartial = modifiedMessages.at(-1)?.partial === true + if (isLastMessagePartial) { + return true + } else { + const lastApiReqStarted = findLast(modifiedMessages, (message) => message.say === "api_req_started") + if (lastApiReqStarted && lastApiReqStarted.text != null && lastApiReqStarted.say === "api_req_started") { + const cost = JSON.parse(lastApiReqStarted.text).cost + if (cost === undefined) { + // api request has not finished yet + return true + } + } + } + + return false + }, [modifiedMessages]) + const handleSendMessage = useCallback( (text: string, images: string[]) => { text = text.trim() @@ -251,6 +270,11 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie }, [claudeAsk, startNewTask]) const handleSecondaryButtonClick = useCallback(() => { + if (isStreaming) { + vscode.postMessage({ type: "cancelTask" }) + return + } + switch (claudeAsk) { case "api_req_failed": case "mistake_limit_reached": @@ -267,7 +291,7 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie setEnableButtons(false) // setPrimaryButtonText(undefined) // setSecondaryButtonText(undefined) - }, [claudeAsk, startNewTask]) + }, [claudeAsk, startNewTask, isStreaming]) const handleTaskCloseButtonClick = useCallback(() => { startNewTask() @@ -544,11 +568,16 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie />
- {primaryButtonText && ( + {primaryButtonText && !isStreaming && ( )} - {secondaryButtonText && ( + {(secondaryButtonText || isStreaming) && ( - {secondaryButtonText} + {isStreaming ? "Cancel" : secondaryButtonText} )}