Replace token estimation with using last API response token usage

This commit is contained in:
Saoud Rizwan
2024-08-30 22:29:18 -04:00
parent dcd6d84632
commit 3e58160d99
6 changed files with 28 additions and 140 deletions

View File

@@ -15,3 +15,8 @@ export function findLastIndex<T>(array: Array<T>, predicate: (value: T, index: n
}
return -1
}
export function findLast<T>(array: Array<T>, predicate: (value: T, index: number, obj: T[]) => boolean): T | undefined {
const index = findLastIndex(array, predicate)
return index === -1 ? undefined : array[index]
}

View File

@@ -1,26 +1,4 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { countTokens } from "@anthropic-ai/tokenizer"
import { Buffer } from "buffer"
import sizeOf from "image-size"
export function isWithinContextWindow(
contextWindow: number,
systemPrompt: string,
tools: Anthropic.Messages.Tool[],
messages: Anthropic.Messages.MessageParam[]
): boolean {
const adjustedContextWindow = contextWindow * 0.75 // Buffer to account for tokenizer differences
// counting tokens is expensive, so we first try to estimate before doing a more accurate calculation
const estimatedTotalMessageTokens = countTokens(systemPrompt + JSON.stringify(tools) + JSON.stringify(messages))
if (estimatedTotalMessageTokens <= adjustedContextWindow) {
return true
}
const systemPromptTokens = countTokens(systemPrompt)
const toolsTokens = countTokens(JSON.stringify(tools))
let availableTokens = adjustedContextWindow - systemPromptTokens - toolsTokens
let accurateTotalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
return accurateTotalMessageTokens <= availableTokens
}
/*
We can't implement a dynamically updating sliding window as it would break prompt cache
@@ -46,54 +24,3 @@ export function truncateHalfConversation(
return truncatedMessages
}
function countMessageTokens(message: Anthropic.Messages.MessageParam): number {
if (typeof message.content === "string") {
return countTokens(message.content)
} else if (Array.isArray(message.content)) {
return message.content.reduce((sum, item) => {
if (typeof item === "string") {
return sum + countTokens(item)
} else if (item.type === "text") {
return sum + countTokens(item.text)
} else if (item.type === "image") {
return sum + estimateImageTokens(item.source.data)
} else if (item.type === "tool_use") {
return sum + countTokens(JSON.stringify(item.input))
} else if (item.type === "tool_result") {
if (Array.isArray(item.content)) {
return (
sum +
item.content.reduce((contentSum, contentItem) => {
if (contentItem.type === "text") {
return contentSum + countTokens(contentItem.text)
} else if (contentItem.type === "image") {
return contentSum + estimateImageTokens(contentItem.source.data)
}
return contentSum + countTokens(JSON.stringify(contentItem))
}, 0)
)
} else {
return sum + countTokens(item.content || "")
}
} else {
return sum + countTokens(JSON.stringify(item))
}
}, 0)
} else {
return countTokens(JSON.stringify(message.content))
}
}
function estimateImageTokens(base64: string): number {
const base64Data = base64.split(";base64,").pop()
if (base64Data) {
const buffer = Buffer.from(base64Data, "base64")
const dimensions = sizeOf(buffer)
if (dimensions.width && dimensions.height) {
// "you can estimate the number of tokens used through this algorithm: tokens = (width px * height px)/750"
return Math.ceil((dimensions.width * dimensions.height) / 750)
}
}
return countTokens(base64)
}