diff --git a/src/api/index.ts b/src/api/index.ts index c468b97..cda63a8 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -1,5 +1,4 @@ import { Anthropic } from "@anthropic-ai/sdk" -import { Stream } from "@anthropic-ai/sdk/streaming" import { ApiConfiguration, ModelInfo } from "../shared/api" import { AnthropicHandler } from "./providers/anthropic" import { AwsBedrockHandler } from "./providers/bedrock" @@ -9,15 +8,14 @@ import { OpenAiHandler } from "./providers/openai" import { OllamaHandler } from "./providers/ollama" import { GeminiHandler } from "./providers/gemini" import { OpenAiNativeHandler } from "./providers/openai-native" - -export type AnthropicStream = Stream +import { ApiStream } from "./transform/stream" export interface ApiHandler { createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], tools: Anthropic.Messages.Tool[] - ): Promise + ): ApiStream getModel(): { id: string; info: ModelInfo } } diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 79a6963..d01a44b 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -1,5 +1,5 @@ import { Anthropic } from "@anthropic-ai/sdk" -import { AnthropicStream, ApiHandler } from "../index" +import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming" import { anthropicDefaultModelId, AnthropicModelId, @@ -7,6 +7,8 @@ import { ApiHandlerOptions, ModelInfo, } from "../../shared/api" +import { ApiHandler } from "../index" +import { ApiStream } from "../transform/stream" export class AnthropicHandler implements ApiHandler { private options: ApiHandlerOptions @@ -20,11 +22,8 @@ export class AnthropicHandler implements ApiHandler { }) } - async createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - tools: Anthropic.Messages.Tool[] - ): Promise { + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + let stream: AnthropicStream const modelId = this.getModel().id switch (modelId) { case "claude-3-5-sonnet-20240620": @@ -39,7 +38,7 @@ export class AnthropicHandler implements ApiHandler { ) const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 - const stream = this.client.beta.promptCaching.messages.create( + stream = await this.client.beta.promptCaching.messages.create( { model: modelId, max_tokens: this.getModel().info.maxTokens, @@ -92,14 +91,10 @@ export class AnthropicHandler implements ApiHandler { } })() ) - - return stream - - // throw new Error("Not implemented") - // return { message } + break } default: { - const stream = await this.client.messages.create({ + stream = (await this.client.messages.create({ model: modelId, max_tokens: this.getModel().info.maxTokens, temperature: 0, @@ -108,8 +103,67 @@ export class AnthropicHandler implements ApiHandler { // tools, // tool_choice: { type: "auto" }, stream: true, - }) - return stream as AnthropicStream + })) as any + break + } + } + + for await (const chunk of stream) { + switch (chunk.type) { + case "message_start": + // tells us cache reads/writes/input/output + const usage = chunk.message.usage + yield { + type: "usage", + inputTokens: usage.input_tokens || 0, + outputTokens: usage.output_tokens || 0, + cacheWriteTokens: usage.cache_creation_input_tokens || 0, + cacheReadTokens: usage.cache_read_input_tokens || 0, + } + break + case "message_delta": + // tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message + + yield { + type: "usage", + inputTokens: 0, + outputTokens: chunk.usage.output_tokens || 0, + cacheWriteTokens: 0, + cacheReadTokens: 0, + } + break + case "message_stop": + // no usage data, just an indicator that the message is done + break + case "content_block_start": + switch (chunk.content_block.type) { + case "text": + // we may receive multiple text blocks, in which case just insert a line break between them + if (chunk.index > 0) { + yield { + type: "text", + text: "\n", + } + } + yield { + type: "text", + text: chunk.content_block.text, + } + break + } + break + case "content_block_delta": + switch (chunk.delta.type) { + case "text_delta": + yield { + type: "text", + text: chunk.delta.text, + } + break + } + break + case "content_block_stop": + break } } } diff --git a/src/api/transform/stream.ts b/src/api/transform/stream.ts new file mode 100644 index 0000000..4f8c562 --- /dev/null +++ b/src/api/transform/stream.ts @@ -0,0 +1,16 @@ +export type ApiStream = AsyncGenerator +export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk + +export interface ApiStreamTextChunk { + type: "text" + text: string +} + +export interface ApiStreamUsageChunk { + type: "usage" + inputTokens: number + outputTokens: number + cacheWriteTokens: number + cacheReadTokens: number + totalCost?: number // openrouter +} diff --git a/src/core/ClaudeDev.ts b/src/core/ClaudeDev.ts index bae78ae..53398d7 100644 --- a/src/core/ClaudeDev.ts +++ b/src/core/ClaudeDev.ts @@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for" import * as path from "path" import { serializeError } from "serialize-error" import * as vscode from "vscode" -import { AnthropicStream, ApiHandler, buildApiHandler } from "../api" +import { ApiHandler, buildApiHandler } from "../api" import { diagnosticsToProblemsString, getNewDiagnostics } from "../integrations/diagnostics" import { formatContentBlockToMarkdown } from "../integrations/misc/export-markdown" import { extractTextFromFile } from "../integrations/misc/extract-text" @@ -41,6 +41,7 @@ import { SYSTEM_PROMPT } from "./prompts/system" import { TOOLS } from "./prompts/tools" import { truncateHalfConversation } from "./sliding-window" import { ClaudeDevProvider } from "./webview/ClaudeDevProvider" +import { ApiStream } from "../api/transform/stream" const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution @@ -1205,7 +1206,7 @@ export class ClaudeDev { } } - async attemptApiRequest(previousApiReqIndex: number): Promise { + async attemptApiRequest(previousApiReqIndex: number): Promise { try { let systemPrompt = await SYSTEM_PROMPT(cwd, this.api.getModel().info.supportsImages) if (this.customInstructions && this.customInstructions.trim()) { @@ -1241,7 +1242,7 @@ ${this.customInstructions.trim()} } } } - const stream = await this.api.createMessage( + const stream = this.api.createMessage( systemPrompt, this.apiConversationHistory, TOOLS(cwd, this.api.getModel().info.supportsImages) @@ -2198,7 +2199,6 @@ ${this.customInstructions.trim()} private didRejectTool = false private presentAssistantMessageLocked = false private presentAssistantMessageHasPendingUpdates = false - private parseTextStreamAccumulator = "" //edit private isEditingExistingFile: boolean | undefined private isEditingFile = false @@ -2206,9 +2206,7 @@ ${this.customInstructions.trim()} private editFileCreatedDirs: string[] = [] private editFileDocumentWasOpen = false - parseTextStream(chunk: string) { - this.parseTextStreamAccumulator += chunk - + parseAssistantMessage(assistantMessage: string) { // let text = "" let textContent: TextContent = { type: "text", @@ -2222,7 +2220,7 @@ ${this.customInstructions.trim()} let currentParamValueLines: string[] = [] let textContentLines: string[] = [] - const rawLines = this.parseTextStreamAccumulator.split("\n") + const rawLines = assistantMessage.split("\n") if (rawLines.length === 1) { const firstLine = rawLines[0].trim() @@ -2374,14 +2372,12 @@ ${this.customInstructions.trim()} try { const stream = await this.attemptApiRequest(previousApiReqIndex) - let cacheCreationInputTokens = 0 - let cacheReadInputTokens = 0 + let cacheWriteTokens = 0 + let cacheReadTokens = 0 let inputTokens = 0 let outputTokens = 0 + let totalCost: number | undefined - // todo add error listeners so we can return api error? or wil lfor await handle that below? - - let apiContentBlocks: Anthropic.ContentBlock[] = [] this.currentStreamingContentIndex = 0 this.assistantMessageContent = [] this.didCompleteReadingStream = false @@ -2390,7 +2386,7 @@ ${this.customInstructions.trim()} this.didRejectTool = false this.presentAssistantMessageLocked = false this.presentAssistantMessageHasPendingUpdates = false - this.parseTextStreamAccumulator = "" + // edit this.isEditingExistingFile = undefined this.isEditingFile = false @@ -2398,47 +2394,21 @@ ${this.customInstructions.trim()} this.editFileCreatedDirs = [] this.editFileDocumentWasOpen = false + let assistantMessage = "" + // TODO: handle error being thrown in stream for await (const chunk of stream) { switch (chunk.type) { - case "message_start": - // tells us cache reads/writes/input/output - const usage = chunk.message.usage - cacheCreationInputTokens += usage.cache_creation_input_tokens || 0 - cacheReadInputTokens += usage.cache_read_input_tokens || 0 - inputTokens += usage.input_tokens || 0 - outputTokens += usage.output_tokens || 0 + case "usage": + inputTokens += chunk.inputTokens + outputTokens += chunk.outputTokens + cacheWriteTokens += chunk.cacheWriteTokens + cacheReadTokens += chunk.cacheReadTokens + totalCost = chunk.totalCost break - case "message_delta": - // tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message - outputTokens += chunk.usage.output_tokens || 0 - break - case "message_stop": - // no usage data, just an indicator that the message is done - break - case "content_block_start": - // await delay(4_000) - switch (chunk.content_block.type) { - case "text": - apiContentBlocks.push(chunk.content_block) - // we may receive multiple text blocks, in which case just insert a line break between them - if (chunk.index > 0) { - this.parseTextStream("\n") - } - this.parseTextStream(chunk.content_block.text) - this.presentAssistantMessage() - break - } - break - case "content_block_delta": - switch (chunk.delta.type) { - case "text_delta": - ;(apiContentBlocks[chunk.index] as Anthropic.TextBlock).text += chunk.delta.text - this.parseTextStream(chunk.delta.text) - this.presentAssistantMessage() - break - } - break - case "content_block_stop": + case "text": + assistantMessage += chunk.text + this.parseAssistantMessage(assistantMessage) + this.presentAssistantMessage() break } } @@ -2454,7 +2424,6 @@ ${this.customInstructions.trim()} this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request } - let totalCost: string | undefined // let inputTokens = response.usage.input_tokens // let outputTokens = response.usage.output_tokens // let cacheCreationInputTokens = @@ -2473,11 +2442,9 @@ ${this.customInstructions.trim()} ...JSON.parse(this.claudeMessages[lastApiReqIndex].text), tokensIn: inputTokens, tokensOut: outputTokens, - cacheWrites: cacheCreationInputTokens, - cacheReads: cacheReadInputTokens, - cost: - totalCost || - this.calculateApiCost(inputTokens, outputTokens, cacheCreationInputTokens, cacheReadInputTokens), + cacheWrites: cacheWriteTokens, + cacheReads: cacheReadTokens, + cost: totalCost ?? this.calculateApiCost(inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens), }) await this.saveClaudeMessages() await this.providerRef.deref()?.postStateToWebview() @@ -2485,25 +2452,17 @@ ${this.customInstructions.trim()} // now add to apiconversationhistory // need to save assistant responses to file before proceeding to tool use since user can exit at any moment and we wouldn't be able to save the assistant's response let didEndLoop = false - if (apiContentBlocks.length > 0) { - // Remove 'partial' prop from assistantContentBlocks - // const blocksWithoutPartial: Anthropic.Messages.ContentBlock[] = this.assistantContentBlocks.map( - // (block) => { - // const { partial, ...rest } = block - // return rest - // } - // ) - await this.addToApiConversationHistory({ role: "assistant", content: apiContentBlocks }) - - // incase the content blocks finished + if (assistantMessage.length > 0) { + await this.addToApiConversationHistory({ + role: "assistant", + content: [{ type: "text", text: assistantMessage }], + }) + // in case the content blocks finished // it may be the api stream finished after the last parsed content block was executed, so we are able to detect out of bounds and set userMessageContentReady to true (not you should not call presentAssistantMessage since if the last block is completed it will be presented again) - const completeBlocks = this.assistantMessageContent.filter((block) => !block.partial) // if there are any partial blocks after the stream ended we can consider them invalid - if (this.currentStreamingContentIndex >= completeBlocks.length) { this.userMessageContentReady = true - //throw new Error("No more content blocks to stream! This shouldn't happen...") // remove and just return after testing } await pWaitFor(() => this.userMessageContentReady) @@ -2511,7 +2470,7 @@ ${this.customInstructions.trim()} const recDidEndLoop = await this.recursivelyMakeClaudeRequests(this.userMessageContent) didEndLoop = recDidEndLoop } else { - // this should never happen! it there's no assistant_responses, that means we got no text or tool_use content blocks from API which we should assume is an error + // if there's no assistant_responses, that means we got no text or tool_use content blocks from API which we should assume is an error await this.say( "error", "Unexpected API Response: The language model did not provide any assistant messages. This may indicate an issue with the API or the model's output."