Add ApiStream generator to interface with providers

Saoud Rizwan
2024-09-28 22:17:50 -04:00
parent 19a0ac00bd
commit 7271152f62
4 changed files with 119 additions and 92 deletions

View File

@@ -1,5 +1,4 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { Stream } from "@anthropic-ai/sdk/streaming"
import { ApiConfiguration, ModelInfo } from "../shared/api"
import { AnthropicHandler } from "./providers/anthropic"
import { AwsBedrockHandler } from "./providers/bedrock"
@@ -9,15 +8,14 @@ import { OpenAiHandler } from "./providers/openai"
import { OllamaHandler } from "./providers/ollama"
import { GeminiHandler } from "./providers/gemini"
import { OpenAiNativeHandler } from "./providers/openai-native"
export type AnthropicStream = Stream<Anthropic.Beta.PromptCaching.Messages.RawPromptCachingBetaMessageStreamEvent>
import { ApiStream } from "./transform/stream"
export interface ApiHandler {
createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<AnthropicStream>
): ApiStream
getModel(): { id: string; info: ModelInfo }
}
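
For illustration only (not part of this commit): a caller might consume the new generator-based contract roughly like this. The handler instance and the systemPrompt/messages/tools values are assumed placeholders.

const stream = handler.createMessage(systemPrompt, messages, tools)
for await (const chunk of stream) {
	switch (chunk.type) {
		case "text":
			// streamed model output arrives incrementally
			process.stdout.write(chunk.text)
			break
		case "usage":
			// token counts may arrive in more than one usage chunk; callers can sum them
			console.log("tokens in/out:", chunk.inputTokens, chunk.outputTokens)
			break
	}
}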

View File

@@ -1,5 +1,5 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicStream, ApiHandler } from "../index"
import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
import {
anthropicDefaultModelId,
AnthropicModelId,
@@ -7,6 +7,8 @@ import {
ApiHandlerOptions,
ModelInfo,
} from "../../shared/api"
import { ApiHandler } from "../index"
import { ApiStream } from "../transform/stream"
export class AnthropicHandler implements ApiHandler {
private options: ApiHandlerOptions
@@ -20,11 +22,8 @@ export class AnthropicHandler implements ApiHandler {
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<AnthropicStream> {
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
let stream: AnthropicStream<Anthropic.Beta.PromptCaching.Messages.RawPromptCachingBetaMessageStreamEvent>
const modelId = this.getModel().id
switch (modelId) {
case "claude-3-5-sonnet-20240620":
@@ -39,7 +38,7 @@ export class AnthropicHandler implements ApiHandler {
)
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
const stream = this.client.beta.promptCaching.messages.create(
stream = await this.client.beta.promptCaching.messages.create(
{
model: modelId,
max_tokens: this.getModel().info.maxTokens,
@@ -92,14 +91,10 @@ export class AnthropicHandler implements ApiHandler {
}
})()
)
return stream
// throw new Error("Not implemented")
// return { message }
break
}
default: {
const stream = await this.client.messages.create({
stream = (await this.client.messages.create({
model: modelId,
max_tokens: this.getModel().info.maxTokens,
temperature: 0,
@@ -108,8 +103,67 @@ export class AnthropicHandler implements ApiHandler {
// tools,
// tool_choice: { type: "auto" },
stream: true,
})
return stream as AnthropicStream
})) as any
break
}
}
for await (const chunk of stream) {
switch (chunk.type) {
case "message_start":
// tells us cache reads/writes/input/output
const usage = chunk.message.usage
yield {
type: "usage",
inputTokens: usage.input_tokens || 0,
outputTokens: usage.output_tokens || 0,
cacheWriteTokens: usage.cache_creation_input_tokens || 0,
cacheReadTokens: usage.cache_read_input_tokens || 0,
}
break
case "message_delta":
// tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message
yield {
type: "usage",
inputTokens: 0,
outputTokens: chunk.usage.output_tokens || 0,
cacheWriteTokens: 0,
cacheReadTokens: 0,
}
break
case "message_stop":
// no usage data, just an indicator that the message is done
break
case "content_block_start":
switch (chunk.content_block.type) {
case "text":
// we may receive multiple text blocks, in which case just insert a line break between them
if (chunk.index > 0) {
yield {
type: "text",
text: "\n",
}
}
yield {
type: "text",
text: chunk.content_block.text,
}
break
}
break
case "content_block_delta":
switch (chunk.delta.type) {
case "text_delta":
yield {
type: "text",
text: chunk.delta.text,
}
break
}
break
case "content_block_stop":
break
}
}
}

View File

@@ -0,0 +1,16 @@
export type ApiStream = AsyncGenerator<ApiStreamChunk>
export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk
export interface ApiStreamTextChunk {
type: "text"
text: string
}
export interface ApiStreamUsageChunk {
type: "usage"
inputTokens: number
outputTokens: number
cacheWriteTokens: number
cacheReadTokens: number
totalCost?: number // openrouter
}
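
As a rough usage sketch (illustrative, not in this commit), the chunk union lends itself to a small aggregator; the "./stream" import path and the helper name collectApiStream are assumptions.

import { ApiStream, ApiStreamUsageChunk } from "./stream"

// Illustrative helper: concatenate text chunks and total usage across a stream.
async function collectApiStream(stream: ApiStream): Promise<{ text: string; usage: ApiStreamUsageChunk }> {
	let text = ""
	const usage: ApiStreamUsageChunk = {
		type: "usage",
		inputTokens: 0,
		outputTokens: 0,
		cacheWriteTokens: 0,
		cacheReadTokens: 0,
	}
	for await (const chunk of stream) {
		if (chunk.type === "text") {
			text += chunk.text
		} else {
			usage.inputTokens += chunk.inputTokens
			usage.outputTokens += chunk.outputTokens
			usage.cacheWriteTokens += chunk.cacheWriteTokens
			usage.cacheReadTokens += chunk.cacheReadTokens
			usage.totalCost = chunk.totalCost ?? usage.totalCost
		}
	}
	return { text, usage }
}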

View File

@@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for"
import * as path from "path"
import { serializeError } from "serialize-error"
import * as vscode from "vscode"
import { AnthropicStream, ApiHandler, buildApiHandler } from "../api"
import { ApiHandler, buildApiHandler } from "../api"
import { diagnosticsToProblemsString, getNewDiagnostics } from "../integrations/diagnostics"
import { formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
import { extractTextFromFile } from "../integrations/misc/extract-text"
@@ -41,6 +41,7 @@ import { SYSTEM_PROMPT } from "./prompts/system"
import { TOOLS } from "./prompts/tools"
import { truncateHalfConversation } from "./sliding-window"
import { ClaudeDevProvider } from "./webview/ClaudeDevProvider"
import { ApiStream } from "../api/transform/stream"
const cwd =
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
@@ -1205,7 +1206,7 @@ export class ClaudeDev {
}
}
async attemptApiRequest(previousApiReqIndex: number): Promise<AnthropicStream> {
async attemptApiRequest(previousApiReqIndex: number): Promise<ApiStream> {
try {
let systemPrompt = await SYSTEM_PROMPT(cwd, this.api.getModel().info.supportsImages)
if (this.customInstructions && this.customInstructions.trim()) {
@@ -1241,7 +1242,7 @@ ${this.customInstructions.trim()}
}
}
}
const stream = await this.api.createMessage(
const stream = this.api.createMessage(
systemPrompt,
this.apiConversationHistory,
TOOLS(cwd, this.api.getModel().info.supportsImages)
@@ -2198,7 +2199,6 @@ ${this.customInstructions.trim()}
private didRejectTool = false
private presentAssistantMessageLocked = false
private presentAssistantMessageHasPendingUpdates = false
private parseTextStreamAccumulator = ""
//edit
private isEditingExistingFile: boolean | undefined
private isEditingFile = false
@@ -2206,9 +2206,7 @@ ${this.customInstructions.trim()}
private editFileCreatedDirs: string[] = []
private editFileDocumentWasOpen = false
parseTextStream(chunk: string) {
this.parseTextStreamAccumulator += chunk
parseAssistantMessage(assistantMessage: string) {
// let text = ""
let textContent: TextContent = {
type: "text",
@@ -2222,7 +2220,7 @@ ${this.customInstructions.trim()}
let currentParamValueLines: string[] = []
let textContentLines: string[] = []
const rawLines = this.parseTextStreamAccumulator.split("\n")
const rawLines = assistantMessage.split("\n")
if (rawLines.length === 1) {
const firstLine = rawLines[0].trim()
@@ -2374,14 +2372,12 @@ ${this.customInstructions.trim()}
try {
const stream = await this.attemptApiRequest(previousApiReqIndex)
let cacheCreationInputTokens = 0
let cacheReadInputTokens = 0
let cacheWriteTokens = 0
let cacheReadTokens = 0
let inputTokens = 0
let outputTokens = 0
let totalCost: number | undefined
// todo add error listeners so we can return api error? or will for await handle that below?
let apiContentBlocks: Anthropic.ContentBlock[] = []
this.currentStreamingContentIndex = 0
this.assistantMessageContent = []
this.didCompleteReadingStream = false
@@ -2390,7 +2386,7 @@ ${this.customInstructions.trim()}
this.didRejectTool = false
this.presentAssistantMessageLocked = false
this.presentAssistantMessageHasPendingUpdates = false
this.parseTextStreamAccumulator = ""
// edit
this.isEditingExistingFile = undefined
this.isEditingFile = false
@@ -2398,49 +2394,23 @@ ${this.customInstructions.trim()}
this.editFileCreatedDirs = []
this.editFileDocumentWasOpen = false
let assistantMessage = ""
// TODO: handle error being thrown in stream
for await (const chunk of stream) {
switch (chunk.type) {
case "message_start":
// tells us cache reads/writes/input/output
const usage = chunk.message.usage
cacheCreationInputTokens += usage.cache_creation_input_tokens || 0
cacheReadInputTokens += usage.cache_read_input_tokens || 0
inputTokens += usage.input_tokens || 0
outputTokens += usage.output_tokens || 0
case "usage":
inputTokens += chunk.inputTokens
outputTokens += chunk.outputTokens
cacheWriteTokens += chunk.cacheWriteTokens
cacheReadTokens += chunk.cacheReadTokens
totalCost = chunk.totalCost
break
case "message_delta":
// tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message
outputTokens += chunk.usage.output_tokens || 0
break
case "message_stop":
// no usage data, just an indicator that the message is done
break
case "content_block_start":
// await delay(4_000)
switch (chunk.content_block.type) {
case "text":
apiContentBlocks.push(chunk.content_block)
// we may receive multiple text blocks, in which case just insert a line break between them
if (chunk.index > 0) {
this.parseTextStream("\n")
}
this.parseTextStream(chunk.content_block.text)
assistantMessage += chunk.text
this.parseAssistantMessage(assistantMessage)
this.presentAssistantMessage()
break
}
break
case "content_block_delta":
switch (chunk.delta.type) {
case "text_delta":
;(apiContentBlocks[chunk.index] as Anthropic.TextBlock).text += chunk.delta.text
this.parseTextStream(chunk.delta.text)
this.presentAssistantMessage()
break
}
break
case "content_block_stop":
break
}
}
this.didCompleteReadingStream = true
@@ -2454,7 +2424,6 @@ ${this.customInstructions.trim()}
this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request
}
let totalCost: string | undefined
// let inputTokens = response.usage.input_tokens
// let outputTokens = response.usage.output_tokens
// let cacheCreationInputTokens =
@@ -2473,11 +2442,9 @@ ${this.customInstructions.trim()}
...JSON.parse(this.claudeMessages[lastApiReqIndex].text),
tokensIn: inputTokens,
tokensOut: outputTokens,
cacheWrites: cacheCreationInputTokens,
cacheReads: cacheReadInputTokens,
cost:
totalCost ||
this.calculateApiCost(inputTokens, outputTokens, cacheCreationInputTokens, cacheReadInputTokens),
cacheWrites: cacheWriteTokens,
cacheReads: cacheReadTokens,
cost: totalCost ?? this.calculateApiCost(inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens),
})
await this.saveClaudeMessages()
await this.providerRef.deref()?.postStateToWebview()
@@ -2485,25 +2452,17 @@ ${this.customInstructions.trim()}
// now add to apiconversationhistory
// need to save assistant responses to file before proceeding to tool use since user can exit at any moment and we wouldn't be able to save the assistant's response
let didEndLoop = false
if (apiContentBlocks.length > 0) {
// Remove 'partial' prop from assistantContentBlocks
// const blocksWithoutPartial: Anthropic.Messages.ContentBlock[] = this.assistantContentBlocks.map(
// (block) => {
// const { partial, ...rest } = block
// return rest
// }
// )
await this.addToApiConversationHistory({ role: "assistant", content: apiContentBlocks })
if (assistantMessage.length > 0) {
await this.addToApiConversationHistory({
role: "assistant",
content: [{ type: "text", text: assistantMessage }],
})
// in case the content blocks finished
// it may be that the api stream finished after the last parsed content block was executed, so we are able to detect out of bounds and set userMessageContentReady to true (note you should not call presentAssistantMessage since if the last block is completed it will be presented again)
const completeBlocks = this.assistantMessageContent.filter((block) => !block.partial) // if there are any partial blocks after the stream ended we can consider them invalid
if (this.currentStreamingContentIndex >= completeBlocks.length) {
this.userMessageContentReady = true
//throw new Error("No more content blocks to stream! This shouldn't happen...") // remove and just return after testing
}
await pWaitFor(() => this.userMessageContentReady)
@@ -2511,7 +2470,7 @@ ${this.customInstructions.trim()}
const recDidEndLoop = await this.recursivelyMakeClaudeRequests(this.userMessageContent)
didEndLoop = recDidEndLoop
} else {
// this should never happen! if there's no assistant_responses, that means we got no text or tool_use content blocks from API which we should assume is an error
// if there's no assistant_responses, that means we got no text or tool_use content blocks from API which we should assume is an error
await this.say(
"error",
"Unexpected API Response: The language model did not provide any assistant messages. This may indicate an issue with the API or the model's output."