Add ApiStream generator to interface with providers

Saoud Rizwan
2024-09-28 22:17:50 -04:00
parent 19a0ac00bd
commit 7271152f62
4 changed files with 119 additions and 92 deletions

View File

@@ -1,5 +1,4 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { Stream } from "@anthropic-ai/sdk/streaming"
 import { ApiConfiguration, ModelInfo } from "../shared/api"
 import { AnthropicHandler } from "./providers/anthropic"
 import { AwsBedrockHandler } from "./providers/bedrock"
@@ -9,15 +8,14 @@ import { OpenAiHandler } from "./providers/openai"
 import { OllamaHandler } from "./providers/ollama"
 import { GeminiHandler } from "./providers/gemini"
 import { OpenAiNativeHandler } from "./providers/openai-native"
+import { ApiStream } from "./transform/stream"
-export type AnthropicStream = Stream<Anthropic.Beta.PromptCaching.Messages.RawPromptCachingBetaMessageStreamEvent>
 export interface ApiHandler {
     createMessage(
         systemPrompt: string,
         messages: Anthropic.Messages.MessageParam[],
         tools: Anthropic.Messages.Tool[]
-    ): Promise<AnthropicStream>
+    ): ApiStream
     getModel(): { id: string; info: ModelInfo }
 }
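
With this change the ApiHandler interface returns an ApiStream (a plain AsyncGenerator) instead of a Promise of an Anthropic-specific stream, so callers no longer await createMessage and instead iterate it. A minimal consumer sketch, assuming the imports resolve as in this file; the logChunks helper below is hypothetical and not part of the commit:

import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "./index"

// Hypothetical helper: drains any provider's stream and prints each chunk.
// ApiStream is an AsyncGenerator, so a plain for await...of loop consumes it.
async function logChunks(
    handler: ApiHandler,
    systemPrompt: string,
    messages: Anthropic.Messages.MessageParam[]
) {
    const stream = handler.createMessage(systemPrompt, messages, [])
    for await (const chunk of stream) {
        if (chunk.type === "text") {
            process.stdout.write(chunk.text)
        } else {
            console.log(`usage: in=${chunk.inputTokens}, out=${chunk.outputTokens}`)
        }
    }
}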

View File

@@ -1,5 +1,5 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { AnthropicStream, ApiHandler } from "../index"
+import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
 import {
     anthropicDefaultModelId,
     AnthropicModelId,
@@ -7,6 +7,8 @@ import {
     ApiHandlerOptions,
     ModelInfo,
 } from "../../shared/api"
+import { ApiHandler } from "../index"
+import { ApiStream } from "../transform/stream"
 export class AnthropicHandler implements ApiHandler {
     private options: ApiHandlerOptions
@@ -20,11 +22,8 @@ export class AnthropicHandler implements ApiHandler {
         })
     }
-    async createMessage(
-        systemPrompt: string,
-        messages: Anthropic.Messages.MessageParam[],
-        tools: Anthropic.Messages.Tool[]
-    ): Promise<AnthropicStream> {
+    async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+        let stream: AnthropicStream<Anthropic.Beta.PromptCaching.Messages.RawPromptCachingBetaMessageStreamEvent>
         const modelId = this.getModel().id
         switch (modelId) {
             case "claude-3-5-sonnet-20240620":
@@ -39,7 +38,7 @@ export class AnthropicHandler implements ApiHandler {
                 )
                 const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
                 const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-                const stream = this.client.beta.promptCaching.messages.create(
+                stream = await this.client.beta.promptCaching.messages.create(
                     {
                         model: modelId,
                         max_tokens: this.getModel().info.maxTokens,
@@ -92,14 +91,10 @@ export class AnthropicHandler implements ApiHandler {
                             }
                         })()
                 )
-                return stream
-                // throw new Error("Not implemented")
-                // return { message }
+                break
             }
             default: {
-                const stream = await this.client.messages.create({
+                stream = (await this.client.messages.create({
                     model: modelId,
                     max_tokens: this.getModel().info.maxTokens,
                     temperature: 0,
@@ -108,8 +103,67 @@ export class AnthropicHandler implements ApiHandler {
                     // tools,
                     // tool_choice: { type: "auto" },
                     stream: true,
-                })
-                return stream as AnthropicStream
+                })) as any
+                break
+            }
+        }
+
+        for await (const chunk of stream) {
+            switch (chunk.type) {
+                case "message_start":
+                    // tells us cache reads/writes/input/output
+                    const usage = chunk.message.usage
+                    yield {
+                        type: "usage",
+                        inputTokens: usage.input_tokens || 0,
+                        outputTokens: usage.output_tokens || 0,
+                        cacheWriteTokens: usage.cache_creation_input_tokens || 0,
+                        cacheReadTokens: usage.cache_read_input_tokens || 0,
+                    }
+                    break
+                case "message_delta":
+                    // tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message
+                    yield {
+                        type: "usage",
+                        inputTokens: 0,
+                        outputTokens: chunk.usage.output_tokens || 0,
+                        cacheWriteTokens: 0,
+                        cacheReadTokens: 0,
+                    }
+                    break
+                case "message_stop":
+                    // no usage data, just an indicator that the message is done
+                    break
+                case "content_block_start":
+                    switch (chunk.content_block.type) {
+                        case "text":
+                            // we may receive multiple text blocks, in which case just insert a line break between them
+                            if (chunk.index > 0) {
+                                yield {
+                                    type: "text",
+                                    text: "\n",
+                                }
+                            }
+                            yield {
+                                type: "text",
+                                text: chunk.content_block.text,
+                            }
+                            break
+                    }
+                    break
+                case "content_block_delta":
+                    switch (chunk.delta.type) {
+                        case "text_delta":
+                            yield {
+                                type: "text",
+                                text: chunk.delta.text,
+                            }
+                            break
+                    }
+                    break
+                case "content_block_stop":
+                    break
             }
         }
     }
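
The handler is now an async generator function (async *createMessage) that normalizes Anthropic's SSE events into provider-agnostic text and usage chunks, so callers never see message_start, content_block_delta, and friends. Only the Anthropic handler is converted in this commit; as a rough illustration only (not code from this repo, and assuming the OpenAI SDK's streaming chat API), another provider could map its deltas onto the same chunk shapes like this:

import OpenAI from "openai"
import { ApiStream } from "../transform/stream"

// Illustrative sketch only: emit the same ApiStream chunk shapes from an OpenAI-style stream.
async function* openAiStyleStream(
    client: OpenAI,
    model: string,
    messages: OpenAI.Chat.ChatCompletionMessageParam[]
): ApiStream {
    const stream = await client.chat.completions.create({
        model,
        messages,
        stream: true,
    })
    for await (const chunk of stream) {
        const delta = chunk.choices[0]?.delta
        if (delta?.content) {
            yield { type: "text", text: delta.content }
        }
        // Some providers only report usage on the final chunk, if at all.
        if (chunk.usage) {
            yield {
                type: "usage",
                inputTokens: chunk.usage.prompt_tokens || 0,
                outputTokens: chunk.usage.completion_tokens || 0,
                cacheWriteTokens: 0,
                cacheReadTokens: 0,
            }
        }
    }
}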

View File

@@ -0,0 +1,16 @@
+export type ApiStream = AsyncGenerator<ApiStreamChunk>
+export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk
+
+export interface ApiStreamTextChunk {
+    type: "text"
+    text: string
+}
+
+export interface ApiStreamUsageChunk {
+    type: "usage"
+    inputTokens: number
+    outputTokens: number
+    cacheWriteTokens: number
+    cacheReadTokens: number
+    totalCost?: number // openrouter
+}
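
These chunk types are the whole contract between providers and the rest of the extension: text chunks carry the streamed assistant output, usage chunks carry token accounting (plus an optional pre-computed totalCost, e.g. from OpenRouter). A small consumption sketch; the collectStream helper is hypothetical and simply folds the generator into final text plus summed usage, which is essentially what the ClaudeDev loop below does:

import { ApiStream, ApiStreamUsageChunk } from "./stream"

// Hypothetical helper: drain an ApiStream into the full text plus aggregated usage.
async function collectStream(stream: ApiStream) {
    let text = ""
    const usage: Omit<ApiStreamUsageChunk, "type"> = {
        inputTokens: 0,
        outputTokens: 0,
        cacheWriteTokens: 0,
        cacheReadTokens: 0,
    }
    for await (const chunk of stream) {
        if (chunk.type === "text") {
            text += chunk.text
        } else {
            usage.inputTokens += chunk.inputTokens
            usage.outputTokens += chunk.outputTokens
            usage.cacheWriteTokens += chunk.cacheWriteTokens
            usage.cacheReadTokens += chunk.cacheReadTokens
            usage.totalCost = chunk.totalCost ?? usage.totalCost
        }
    }
    return { text, usage }
}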

View File

@@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for"
 import * as path from "path"
 import { serializeError } from "serialize-error"
 import * as vscode from "vscode"
-import { AnthropicStream, ApiHandler, buildApiHandler } from "../api"
+import { ApiHandler, buildApiHandler } from "../api"
 import { diagnosticsToProblemsString, getNewDiagnostics } from "../integrations/diagnostics"
 import { formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
 import { extractTextFromFile } from "../integrations/misc/extract-text"
@@ -41,6 +41,7 @@ import { SYSTEM_PROMPT } from "./prompts/system"
 import { TOOLS } from "./prompts/tools"
 import { truncateHalfConversation } from "./sliding-window"
 import { ClaudeDevProvider } from "./webview/ClaudeDevProvider"
+import { ApiStream } from "../api/transform/stream"
 const cwd =
     vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
@@ -1205,7 +1206,7 @@ export class ClaudeDev {
         }
     }
-    async attemptApiRequest(previousApiReqIndex: number): Promise<AnthropicStream> {
+    async attemptApiRequest(previousApiReqIndex: number): Promise<ApiStream> {
         try {
             let systemPrompt = await SYSTEM_PROMPT(cwd, this.api.getModel().info.supportsImages)
             if (this.customInstructions && this.customInstructions.trim()) {
@@ -1241,7 +1242,7 @@ ${this.customInstructions.trim()}
                 }
             }
         }
-        const stream = await this.api.createMessage(
+        const stream = this.api.createMessage(
             systemPrompt,
             this.apiConversationHistory,
             TOOLS(cwd, this.api.getModel().info.supportsImages)
@@ -2198,7 +2199,6 @@ ${this.customInstructions.trim()}
     private didRejectTool = false
     private presentAssistantMessageLocked = false
     private presentAssistantMessageHasPendingUpdates = false
-    private parseTextStreamAccumulator = ""
     //edit
     private isEditingExistingFile: boolean | undefined
     private isEditingFile = false
@@ -2206,9 +2206,7 @@ ${this.customInstructions.trim()}
     private editFileCreatedDirs: string[] = []
     private editFileDocumentWasOpen = false
-    parseTextStream(chunk: string) {
-        this.parseTextStreamAccumulator += chunk
+    parseAssistantMessage(assistantMessage: string) {
         // let text = ""
         let textContent: TextContent = {
             type: "text",
@@ -2222,7 +2220,7 @@ ${this.customInstructions.trim()}
         let currentParamValueLines: string[] = []
         let textContentLines: string[] = []
-        const rawLines = this.parseTextStreamAccumulator.split("\n")
+        const rawLines = assistantMessage.split("\n")
         if (rawLines.length === 1) {
             const firstLine = rawLines[0].trim()
@@ -2374,14 +2372,12 @@ ${this.customInstructions.trim()}
         try {
             const stream = await this.attemptApiRequest(previousApiReqIndex)
-            let cacheCreationInputTokens = 0
-            let cacheReadInputTokens = 0
+            let cacheWriteTokens = 0
+            let cacheReadTokens = 0
             let inputTokens = 0
             let outputTokens = 0
-            // todo add error listeners so we can return api error? or wil lfor await handle that below?
-            let apiContentBlocks: Anthropic.ContentBlock[] = []
+            let totalCost: number | undefined
             this.currentStreamingContentIndex = 0
             this.assistantMessageContent = []
             this.didCompleteReadingStream = false
@@ -2390,7 +2386,7 @@ ${this.customInstructions.trim()}
             this.didRejectTool = false
             this.presentAssistantMessageLocked = false
             this.presentAssistantMessageHasPendingUpdates = false
-            this.parseTextStreamAccumulator = ""
             // edit
             this.isEditingExistingFile = undefined
             this.isEditingFile = false
@@ -2398,49 +2394,23 @@ ${this.customInstructions.trim()}
             this.editFileCreatedDirs = []
             this.editFileDocumentWasOpen = false
+            let assistantMessage = ""
+            // TODO: handle error being thrown in stream
             for await (const chunk of stream) {
                 switch (chunk.type) {
-                    case "message_start":
-                        // tells us cache reads/writes/input/output
-                        const usage = chunk.message.usage
-                        cacheCreationInputTokens += usage.cache_creation_input_tokens || 0
-                        cacheReadInputTokens += usage.cache_read_input_tokens || 0
-                        inputTokens += usage.input_tokens || 0
-                        outputTokens += usage.output_tokens || 0
+                    case "usage":
+                        inputTokens += chunk.inputTokens
+                        outputTokens += chunk.outputTokens
+                        cacheWriteTokens += chunk.cacheWriteTokens
+                        cacheReadTokens += chunk.cacheReadTokens
+                        totalCost = chunk.totalCost
                         break
-                    case "message_delta":
-                        // tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message
-                        outputTokens += chunk.usage.output_tokens || 0
-                        break
-                    case "message_stop":
-                        // no usage data, just an indicator that the message is done
-                        break
-                    case "content_block_start":
-                        // await delay(4_000)
-                        switch (chunk.content_block.type) {
                     case "text":
-                        apiContentBlocks.push(chunk.content_block)
-                        // we may receive multiple text blocks, in which case just insert a line break between them
-                        if (chunk.index > 0) {
-                            this.parseTextStream("\n")
-                        }
-                        this.parseTextStream(chunk.content_block.text)
+                        assistantMessage += chunk.text
+                        this.parseAssistantMessage(assistantMessage)
                         this.presentAssistantMessage()
                         break
                 }
-                        break
-                    case "content_block_delta":
-                        switch (chunk.delta.type) {
-                            case "text_delta":
-                                ;(apiContentBlocks[chunk.index] as Anthropic.TextBlock).text += chunk.delta.text
-                                this.parseTextStream(chunk.delta.text)
-                                this.presentAssistantMessage()
-                                break
-                        }
-                        break
-                    case "content_block_stop":
-                        break
-                }
             }
             this.didCompleteReadingStream = true
@@ -2454,7 +2424,6 @@ ${this.customInstructions.trim()}
                 this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request
             }
-            let totalCost: string | undefined
             // let inputTokens = response.usage.input_tokens
             // let outputTokens = response.usage.output_tokens
             // let cacheCreationInputTokens =
@@ -2473,11 +2442,9 @@ ${this.customInstructions.trim()}
                 ...JSON.parse(this.claudeMessages[lastApiReqIndex].text),
                 tokensIn: inputTokens,
                 tokensOut: outputTokens,
-                cacheWrites: cacheCreationInputTokens,
-                cacheReads: cacheReadInputTokens,
-                cost:
-                    totalCost ||
-                    this.calculateApiCost(inputTokens, outputTokens, cacheCreationInputTokens, cacheReadInputTokens),
+                cacheWrites: cacheWriteTokens,
+                cacheReads: cacheReadTokens,
+                cost: totalCost ?? this.calculateApiCost(inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens),
             })
             await this.saveClaudeMessages()
             await this.providerRef.deref()?.postStateToWebview()
@@ -2485,25 +2452,17 @@ ${this.customInstructions.trim()}
             // now add to apiconversationhistory
             // need to save assistant responses to file before proceeding to tool use since user can exit at any moment and we wouldn't be able to save the assistant's response
             let didEndLoop = false
-            if (apiContentBlocks.length > 0) {
-                // Remove 'partial' prop from assistantContentBlocks
-                // const blocksWithoutPartial: Anthropic.Messages.ContentBlock[] = this.assistantContentBlocks.map(
-                //     (block) => {
-                //         const { partial, ...rest } = block
-                //         return rest
-                //     }
-                // )
-                await this.addToApiConversationHistory({ role: "assistant", content: apiContentBlocks })
+            if (assistantMessage.length > 0) {
+                await this.addToApiConversationHistory({
+                    role: "assistant",
+                    content: [{ type: "text", text: assistantMessage }],
+                })
                 // in case the content blocks finished
                 // it may be the api stream finished after the last parsed content block was executed, so we are able to detect out of bounds and set userMessageContentReady to true (not you should not call presentAssistantMessage since if the last block is completed it will be presented again)
                 const completeBlocks = this.assistantMessageContent.filter((block) => !block.partial) // if there are any partial blocks after the stream ended we can consider them invalid
                 if (this.currentStreamingContentIndex >= completeBlocks.length) {
                     this.userMessageContentReady = true
-                    //throw new Error("No more content blocks to stream! This shouldn't happen...") // remove and just return after testing
                 }
                 await pWaitFor(() => this.userMessageContentReady)
@@ -2511,7 +2470,7 @@ ${this.customInstructions.trim()}
                 const recDidEndLoop = await this.recursivelyMakeClaudeRequests(this.userMessageContent)
                 didEndLoop = recDidEndLoop
             } else {
-                // this should never happen! it there's no assistant_responses, that means we got no text or tool_use content blocks from API which we should assume is an error
+                // if there's no assistant_responses, that means we got no text or tool_use content blocks from API which we should assume is an error
                 await this.say(
                     "error",
                     "Unexpected API Response: The language model did not provide any assistant messages. This may indicate an issue with the API or the model's output."