Add Cancel button

This commit is contained in:
Saoud Rizwan
2024-09-30 18:47:33 -04:00
parent c2a2e1b54c
commit 9b1b9c10a1
6 changed files with 167 additions and 68 deletions

View File

@@ -21,7 +21,7 @@ import { ApiConfiguration } from "../shared/api"
import { findLastIndex } from "../shared/array" import { findLastIndex } from "../shared/array"
import { combineApiRequests } from "../shared/combineApiRequests" import { combineApiRequests } from "../shared/combineApiRequests"
import { combineCommandSequences } from "../shared/combineCommandSequences" import { combineCommandSequences } from "../shared/combineCommandSequences"
import { ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "../shared/ExtensionMessage" import { ClaudeApiReqInfo, ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "../shared/ExtensionMessage"
import { getApiMetrics } from "../shared/getApiMetrics" import { getApiMetrics } from "../shared/getApiMetrics"
import { HistoryItem } from "../shared/HistoryItem" import { HistoryItem } from "../shared/HistoryItem"
import { ToolName } from "../shared/Tool" import { ToolName } from "../shared/Tool"
@@ -69,6 +69,7 @@ export class ClaudeDev {
private consecutiveMistakeCount: number = 0 private consecutiveMistakeCount: number = 0
private providerRef: WeakRef<ClaudeDevProvider> private providerRef: WeakRef<ClaudeDevProvider>
private abort: boolean = false private abort: boolean = false
didFinishAborting = false
private diffViewProvider: DiffViewProvider private diffViewProvider: DiffViewProvider
// streaming // streaming
@@ -381,19 +382,6 @@ export class ClaudeDev {
private async resumeTaskFromHistory() { private async resumeTaskFromHistory() {
const modifiedClaudeMessages = await this.getSavedClaudeMessages() const modifiedClaudeMessages = await this.getSavedClaudeMessages()
// Need to modify claude messages for good ux, i.e. if the last message is an api_request_started, then remove it otherwise the user will think the request is still loading
const lastApiReqStartedIndex = modifiedClaudeMessages.reduce(
(lastIndex, m, index) => (m.type === "say" && m.say === "api_req_started" ? index : lastIndex),
-1
)
const lastApiReqFinishedIndex = modifiedClaudeMessages.reduce(
(lastIndex, m, index) => (m.type === "say" && m.say === "api_req_finished" ? index : lastIndex),
-1
)
if (lastApiReqStartedIndex > lastApiReqFinishedIndex && lastApiReqStartedIndex !== -1) {
modifiedClaudeMessages.splice(lastApiReqStartedIndex, 1)
}
// Remove any resume messages that may have been added before // Remove any resume messages that may have been added before
const lastRelevantMessageIndex = findLastIndex( const lastRelevantMessageIndex = findLastIndex(
modifiedClaudeMessages, modifiedClaudeMessages,
@@ -403,6 +391,23 @@ export class ClaudeDev {
modifiedClaudeMessages.splice(lastRelevantMessageIndex + 1) modifiedClaudeMessages.splice(lastRelevantMessageIndex + 1)
} }
// if the last message is an api_req_started it means there was no partial content streamed, so we remove it
if (modifiedClaudeMessages.at(-1)?.say === "api_req_started") {
modifiedClaudeMessages.pop()
}
// since we don't use api_req_finished anymore, we need to check if the last api_req_started has a cost value, if it doesn't and it's not cancelled, then we remove it since it indicates an api request without any partial content streamed
// const lastApiReqStartedIndex = findLastIndex(
// modifiedClaudeMessages,
// (m) => m.type === "say" && m.say === "api_req_started"
// )
// if (lastApiReqStartedIndex !== -1) {
// const lastApiReqStarted = modifiedClaudeMessages[lastApiReqStartedIndex]
// const { cost, cancelled }: ClaudeApiReqInfo = JSON.parse(lastApiReqStarted.text || "{}")
// if (cost === undefined || cancelled) {
// modifiedClaudeMessages.splice(lastApiReqStartedIndex, 1)
// }
// }
await this.overwriteClaudeMessages(modifiedClaudeMessages) await this.overwriteClaudeMessages(modifiedClaudeMessages)
this.claudeMessages = await this.getSavedClaudeMessages() this.claudeMessages = await this.getSavedClaudeMessages()
@@ -698,13 +703,9 @@ export class ClaudeDev {
if (previousApiReqIndex >= 0) { if (previousApiReqIndex >= 0) {
const previousRequest = this.claudeMessages[previousApiReqIndex] const previousRequest = this.claudeMessages[previousApiReqIndex]
if (previousRequest && previousRequest.text) { if (previousRequest && previousRequest.text) {
const { const { tokensIn, tokensOut, cacheWrites, cacheReads }: ClaudeApiReqInfo = JSON.parse(
tokensIn, previousRequest.text
tokensOut, )
cacheWrites,
cacheReads,
}: { tokensIn?: number; tokensOut?: number; cacheWrites?: number; cacheReads?: number } =
JSON.parse(previousRequest.text)
const totalTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0) const totalTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
const contextWindow = this.api.getModel().info.contextWindow const contextWindow = this.api.getModel().info.contextWindow
const maxAllowedSize = Math.max(contextWindow - 40_000, contextWindow * 0.8) const maxAllowedSize = Math.max(contextWindow - 40_000, contextWindow * 0.8)
@@ -1584,7 +1585,7 @@ export class ClaudeDev {
request: userContent request: userContent
.map((block) => formatContentBlockToMarkdown(block, this.apiConversationHistory)) .map((block) => formatContentBlockToMarkdown(block, this.apiConversationHistory))
.join("\n\n"), .join("\n\n"),
}) } satisfies ClaudeApiReqInfo)
await this.saveClaudeMessages() await this.saveClaudeMessages()
await this.providerRef.deref()?.postStateToWebview() await this.providerRef.deref()?.postStateToWebview()
@@ -1596,6 +1597,29 @@ export class ClaudeDev {
let outputTokens = 0 let outputTokens = 0
let totalCost: number | undefined let totalCost: number | undefined
// update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed)
// fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history
// (it's worth removing a few months from now)
const updateApiReqMsg = (cancelled?: boolean) => {
this.claudeMessages[lastApiReqIndex].text = JSON.stringify({
...JSON.parse(this.claudeMessages[lastApiReqIndex].text || "{}"),
tokensIn: inputTokens,
tokensOut: outputTokens,
cacheWrites: cacheWriteTokens,
cacheReads: cacheReadTokens,
cost:
totalCost ??
calculateApiCost(
this.api.getModel().info,
inputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens
),
cancelled,
} satisfies ClaudeApiReqInfo)
}
// reset streaming state // reset streaming state
this.currentStreamingContentIndex = 0 this.currentStreamingContentIndex = 0
this.assistantMessageContent = [] this.assistantMessageContent = []
@@ -1624,6 +1648,42 @@ export class ClaudeDev {
this.presentAssistantMessage() this.presentAssistantMessage()
break break
} }
if (this.abort) {
console.log("aborting stream...")
if (this.diffViewProvider.isEditing) {
await this.diffViewProvider.revertChanges() // closes diff view
}
// if last message is a partial we need to save it
const lastMessage = this.claudeMessages.at(-1)
if (lastMessage && lastMessage.partial) {
lastMessage.ts = Date.now()
lastMessage.partial = false
// instead of streaming partialMessage events, we do a save and post like normal to persist to disk
console.log("saving messages...", lastMessage)
// await this.saveClaudeMessages()
}
//
await this.addToApiConversationHistory({
role: "assistant",
content: [{ type: "text", text: assistantMessage + "\n\n[Response interrupted by user]" }],
})
// update api_req_started to have cancelled and cost, so that we can display the cost of the partial stream
updateApiReqMsg(true)
await this.saveClaudeMessages()
// signals to provider that it can retrieve the saved messages from disk, as abortTask can not be awaited on in nature
this.didFinishAborting = true
break // aborts the stream
}
}
// need to call here in case the stream was aborted
if (this.abort) {
throw new Error("ClaudeDev instance aborted")
} }
this.didCompleteReadingStream = true this.didCompleteReadingStream = true
@@ -1637,36 +1697,7 @@ export class ClaudeDev {
this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request
} }
// let inputTokens = response.usage.input_tokens updateApiReqMsg()
// let outputTokens = response.usage.output_tokens
// let cacheCreationInputTokens =
// (response as Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaMessage).usage
// .cache_creation_input_tokens || undefined
// let cacheReadInputTokens =
// (response as Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaMessage).usage
// .cache_read_input_tokens || undefined
// @ts-ignore-next-line
// let totalCost = response.usage.total_cost
// update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed)
// fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history
// (it's worth removing a few months from now)
this.claudeMessages[lastApiReqIndex].text = JSON.stringify({
...JSON.parse(this.claudeMessages[lastApiReqIndex].text),
tokensIn: inputTokens,
tokensOut: outputTokens,
cacheWrites: cacheWriteTokens,
cacheReads: cacheReadTokens,
cost:
totalCost ??
calculateApiCost(
this.api.getModel().info,
inputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens
),
})
await this.saveClaudeMessages() await this.saveClaudeMessages()
await this.providerRef.deref()?.postStateToWebview() await this.providerRef.deref()?.postStateToWebview()

View File

@@ -19,6 +19,7 @@ import WorkspaceTracker from "../../integrations/workspace/WorkspaceTracker"
import { openMention } from "../mentions" import { openMention } from "../mentions"
import { fileExistsAtPath } from "../../utils/fs" import { fileExistsAtPath } from "../../utils/fs"
import { buildApiHandler } from "../../api" import { buildApiHandler } from "../../api"
import pWaitFor from "p-wait-for"
/* /*
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -441,6 +442,18 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
break break
case "openMention": case "openMention":
openMention(message.text) openMention(message.text)
break
case "cancelTask":
if (this.claudeDev) {
const { historyItem } = await this.getTaskWithId(this.claudeDev.taskId)
this.claudeDev.abortTask()
await pWaitFor(() => this.claudeDev === undefined || this.claudeDev.didFinishAborting, {
timeout: 3_000,
})
await this.initClaudeDevWithHistoryItem(historyItem) // clears task again, so we need to abortTask manually above
await this.postStateToWebview()
}
break break
// Add more switch case statements here as more webview message commands // Add more switch case statements here as more webview message commands
// are created within the webview context (i.e. inside media/main.js) // are created within the webview context (i.e. inside media/main.js)

View File

@@ -87,3 +87,13 @@ export interface ClaudeSayTool {
regex?: string regex?: string
filePattern?: string filePattern?: string
} }
export interface ClaudeApiReqInfo {
request?: string
tokensIn?: number
tokensOut?: number
cacheWrites?: number
cacheReads?: number
cost?: number
cancelled?: boolean
}

View File

@@ -20,6 +20,7 @@ export interface WebviewMessage {
| "openImage" | "openImage"
| "openFile" | "openFile"
| "openMention" | "openMention"
| "cancelTask"
text?: string text?: string
askResponse?: ClaudeAskResponse askResponse?: ClaudeAskResponse
apiConfiguration?: ApiConfiguration apiConfiguration?: ApiConfiguration

View File

@@ -2,7 +2,7 @@ import { VSCodeBadge, VSCodeProgressRing } from "@vscode/webview-ui-toolkit/reac
import deepEqual from "fast-deep-equal" import deepEqual from "fast-deep-equal"
import React, { memo, useMemo } from "react" import React, { memo, useMemo } from "react"
import ReactMarkdown from "react-markdown" import ReactMarkdown from "react-markdown"
import { ClaudeMessage, ClaudeSayTool } from "../../../../src/shared/ExtensionMessage" import { ClaudeApiReqInfo, ClaudeMessage, ClaudeSayTool } from "../../../../src/shared/ExtensionMessage"
import { COMMAND_OUTPUT_STRING } from "../../../../src/shared/combineCommandSequences" import { COMMAND_OUTPUT_STRING } from "../../../../src/shared/combineCommandSequences"
import { vscode } from "../../utils/vscode" import { vscode } from "../../utils/vscode"
import CodeAccordian, { removeLeadingNonAlphanumeric } from "../common/CodeAccordian" import CodeAccordian, { removeLeadingNonAlphanumeric } from "../common/CodeAccordian"
@@ -37,11 +37,12 @@ const ChatRow = memo(
export default ChatRow export default ChatRow
const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessage, isLast }: ChatRowProps) => { const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessage, isLast }: ChatRowProps) => {
const cost = useMemo(() => { const [cost, apiReqCancelled] = useMemo(() => {
if (message.text != null && message.say === "api_req_started") { if (message.text != null && message.say === "api_req_started") {
return JSON.parse(message.text).cost const info: ClaudeApiReqInfo = JSON.parse(message.text)
return [info.cost, info.cancelled]
} }
return undefined return [undefined, undefined]
}, [message.text, message.say]) }, [message.text, message.say])
const apiRequestFailedMessage = const apiRequestFailedMessage =
isLast && lastModifiedMessage?.ask === "api_req_failed" // if request is retried then the latest message is a api_req_retried isLast && lastModifiedMessage?.ask === "api_req_failed" // if request is retried then the latest message is a api_req_retried
@@ -54,6 +55,7 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
const normalColor = "var(--vscode-foreground)" const normalColor = "var(--vscode-foreground)"
const errorColor = "var(--vscode-errorForeground)" const errorColor = "var(--vscode-errorForeground)"
const successColor = "var(--vscode-charts-green)" const successColor = "var(--vscode-charts-green)"
const cancelledColor = "var(--vscode-descriptionForeground)"
const [icon, title] = useMemo(() => { const [icon, title] = useMemo(() => {
switch (type) { switch (type) {
@@ -94,9 +96,15 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
case "api_req_started": case "api_req_started":
return [ return [
cost != null ? ( cost != null ? (
apiReqCancelled ? (
<span
className="codicon codicon-error"
style={{ color: cancelledColor, marginBottom: "-1.5px" }}></span>
) : (
<span <span
className="codicon codicon-check" className="codicon codicon-check"
style={{ color: successColor, marginBottom: "-1.5px" }}></span> style={{ color: successColor, marginBottom: "-1.5px" }}></span>
)
) : apiRequestFailedMessage ? ( ) : apiRequestFailedMessage ? (
<span <span
className="codicon codicon-error" className="codicon codicon-error"
@@ -105,7 +113,11 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
<ProgressIndicator /> <ProgressIndicator />
), ),
cost != null ? ( cost != null ? (
apiReqCancelled ? (
<span style={{ color: normalColor, fontWeight: "bold" }}>API Request Cancelled</span>
) : (
<span style={{ color: normalColor, fontWeight: "bold" }}>API Request</span> <span style={{ color: normalColor, fontWeight: "bold" }}>API Request</span>
)
) : apiRequestFailedMessage ? ( ) : apiRequestFailedMessage ? (
<span style={{ color: errorColor, fontWeight: "bold" }}>API Request Failed</span> <span style={{ color: errorColor, fontWeight: "bold" }}>API Request Failed</span>
) : ( ) : (
@@ -122,7 +134,7 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
default: default:
return [null, null] return [null, null]
} }
}, [type, cost, apiRequestFailedMessage, isCommandExecuting]) }, [type, cost, apiRequestFailedMessage, isCommandExecuting, apiReqCancelled])
const headerStyle: React.CSSProperties = { const headerStyle: React.CSSProperties = {
display: "flex", display: "flex",

View File

@@ -14,6 +14,7 @@ import ChatRow from "./ChatRow"
import ChatTextArea from "./ChatTextArea" import ChatTextArea from "./ChatTextArea"
import HistoryPreview from "../history/HistoryPreview" import HistoryPreview from "../history/HistoryPreview"
import TaskHeader from "./TaskHeader" import TaskHeader from "./TaskHeader"
import { findLast } from "../../../../src/shared/array"
interface ChatViewProps { interface ChatViewProps {
isHidden: boolean isHidden: boolean
@@ -182,6 +183,24 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
} }
}, [messages.length]) }, [messages.length])
const isStreaming = useMemo(() => {
const isLastMessagePartial = modifiedMessages.at(-1)?.partial === true
if (isLastMessagePartial) {
return true
} else {
const lastApiReqStarted = findLast(modifiedMessages, (message) => message.say === "api_req_started")
if (lastApiReqStarted && lastApiReqStarted.text != null && lastApiReqStarted.say === "api_req_started") {
const cost = JSON.parse(lastApiReqStarted.text).cost
if (cost === undefined) {
// api request has not finished yet
return true
}
}
}
return false
}, [modifiedMessages])
const handleSendMessage = useCallback( const handleSendMessage = useCallback(
(text: string, images: string[]) => { (text: string, images: string[]) => {
text = text.trim() text = text.trim()
@@ -251,6 +270,11 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
}, [claudeAsk, startNewTask]) }, [claudeAsk, startNewTask])
const handleSecondaryButtonClick = useCallback(() => { const handleSecondaryButtonClick = useCallback(() => {
if (isStreaming) {
vscode.postMessage({ type: "cancelTask" })
return
}
switch (claudeAsk) { switch (claudeAsk) {
case "api_req_failed": case "api_req_failed":
case "mistake_limit_reached": case "mistake_limit_reached":
@@ -267,7 +291,7 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
setEnableButtons(false) setEnableButtons(false)
// setPrimaryButtonText(undefined) // setPrimaryButtonText(undefined)
// setSecondaryButtonText(undefined) // setSecondaryButtonText(undefined)
}, [claudeAsk, startNewTask]) }, [claudeAsk, startNewTask, isStreaming])
const handleTaskCloseButtonClick = useCallback(() => { const handleTaskCloseButtonClick = useCallback(() => {
startNewTask() startNewTask()
@@ -544,11 +568,16 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
/> />
<div <div
style={{ style={{
opacity: primaryButtonText || secondaryButtonText ? (enableButtons ? 1 : 0.5) : 0, opacity:
primaryButtonText || secondaryButtonText || isStreaming
? enableButtons || isStreaming
? 1
: 0.5
: 0,
display: "flex", display: "flex",
padding: "10px 15px 0px 15px", padding: "10px 15px 0px 15px",
}}> }}>
{primaryButtonText && ( {primaryButtonText && !isStreaming && (
<VSCodeButton <VSCodeButton
appearance="primary" appearance="primary"
disabled={!enableButtons} disabled={!enableButtons}
@@ -560,13 +589,16 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
{primaryButtonText} {primaryButtonText}
</VSCodeButton> </VSCodeButton>
)} )}
{secondaryButtonText && ( {(secondaryButtonText || isStreaming) && (
<VSCodeButton <VSCodeButton
appearance="secondary" appearance="secondary"
disabled={!enableButtons} disabled={!enableButtons && !isStreaming}
style={{ flex: 1, marginLeft: "6px" }} style={{
flex: isStreaming ? 2 : 1,
marginLeft: isStreaming ? 0 : "6px",
}}
onClick={handleSecondaryButtonClick}> onClick={handleSecondaryButtonClick}>
{secondaryButtonText} {isStreaming ? "Cancel" : secondaryButtonText}
</VSCodeButton> </VSCodeButton>
)} )}
</div> </div>