Improve sliding window algorithm to not break prompt cache so often

Saoud Rizwan
2024-08-28 20:32:58 -04:00
parent 98fdf898be
commit a160e8d67b
2 changed files with 49 additions and 61 deletions

View File

@@ -25,7 +25,7 @@ import { HistoryItem } from "./shared/HistoryItem"
 import { combineApiRequests } from "./shared/combineApiRequests"
 import { combineCommandSequences } from "./shared/combineCommandSequences"
 import { findLastIndex } from "./utils"
-import { slidingWindowContextManagement } from "./utils/context-management"
+import { isWithinContextWindow, truncateHalfConversation } from "./utils/context-management"
 
 const SYSTEM_PROMPT =
 	() => `You are Claude Dev, a highly skilled software developer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
@@ -1253,13 +1253,21 @@ The following additional instructions are provided by the user. They should be f
 ${this.customInstructions.trim()}
 `
 		}
-		const adjustedMessages = slidingWindowContextManagement(
+		const isPromptWithinContextWindow = isWithinContextWindow(
 			this.api.getModel().info.contextWindow,
+			systemPrompt,
+			tools,
+			this.apiConversationHistory
+		)
+		if (!isPromptWithinContextWindow) {
+			const truncatedMessages = truncateHalfConversation(this.apiConversationHistory)
+			await this.overwriteApiConversationHistory(truncatedMessages)
+		}
+		const { message, userCredits } = await this.api.createMessage(
 			systemPrompt,
 			this.apiConversationHistory,
 			tools
 		)
-		const { message, userCredits } = await this.api.createMessage(systemPrompt, adjustedMessages, tools)
 		if (userCredits !== undefined) {
 			console.log("Updating credits", userCredits)
 			// TODO: update credits
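
Note: the hunk above swaps the old per-request slidingWindowContextManagement rewrite for a check-then-truncate step. A condensed sketch of that caller-side flow, using only the two helpers this commit exports from "./utils/context-management" (the prepareApiMessages wrapper is illustrative, not part of the commit):

    import { Anthropic } from "@anthropic-ai/sdk"
    import { isWithinContextWindow, truncateHalfConversation } from "./utils/context-management"

    // Illustrative wrapper, not in the commit: decide once per request whether
    // truncation is needed. Editing earlier messages invalidates any cached
    // prompt prefix, so the history is only rewritten when it no longer fits.
    function prepareApiMessages(
        contextWindow: number,
        systemPrompt: string,
        tools: Anthropic.Messages.Tool[],
        history: Anthropic.Messages.MessageParam[]
    ): Anthropic.Messages.MessageParam[] {
        if (isWithinContextWindow(contextWindow, systemPrompt, tools, history)) {
            return history // untouched history keeps the cached prefix valid
        }
        // one static cut that drops roughly the first half of the conversation
        return truncateHalfConversation(history)
    }

Keeping the history byte-identical on every request that fits is what lets prompt caching keep hitting; the cut only happens at overflow, then the new, shorter prefix becomes the stable cache key.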

View File

@@ -2,74 +2,54 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { countTokens } from "@anthropic-ai/tokenizer"
 import { Buffer } from "buffer"
 import sizeOf from "image-size"
-import cloneDeep from "clone-deep"
 
-export function slidingWindowContextManagement(
+export function isWithinContextWindow(
 	contextWindow: number,
 	systemPrompt: string,
-	messages: Anthropic.Messages.MessageParam[],
-	tools: Anthropic.Messages.Tool[]
-): Anthropic.Messages.MessageParam[] {
+	tools: Anthropic.Messages.Tool[],
+	messages: Anthropic.Messages.MessageParam[]
+): boolean {
 	const adjustedContextWindow = contextWindow - 10_000 // Buffer to account for tokenizer differences
+	// counting tokens is expensive, so we first try to estimate before doing a more accurate calculation
+	const estimatedTotalMessageTokens = countTokens(systemPrompt + JSON.stringify(tools) + JSON.stringify(messages))
+	if (estimatedTotalMessageTokens <= adjustedContextWindow) {
+		return true
+	}
 	const systemPromptTokens = countTokens(systemPrompt)
 	const toolsTokens = countTokens(JSON.stringify(tools))
 	let availableTokens = adjustedContextWindow - systemPromptTokens - toolsTokens
-	let totalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
-
-	if (totalMessageTokens <= availableTokens) {
-		return messages
-	}
-
-	// If over limit, remove messages starting from the third message onwards (task and claude's step-by-step thought process are important to keep in context)
-	const newMessages = cloneDeep(messages) // since we're manipulating nested objects and arrays, need to deep clone to prevent mutating original history
-	let index = 2
-	while (totalMessageTokens > availableTokens && index < newMessages.length) {
-		const messageToEmpty = newMessages[index]
-		const originalTokens = countMessageTokens(messageToEmpty)
-		// Empty the content of the message (messages must be in a specific order so we can't just remove)
-		if (typeof messageToEmpty.content === "string") {
-			messageToEmpty.content = "(truncated due to context limits)"
-		} else if (Array.isArray(messageToEmpty.content)) {
-			messageToEmpty.content = messageToEmpty.content.map((item) => {
-				if (typeof item === "string") {
-					return {
-						type: "text",
-						text: "(truncated due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "text") {
-					return {
-						type: "text",
-						text: "(truncated due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "image") {
-					return {
-						type: "text",
-						text: "(image removed due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "tool_use") {
-					return { ...item, input: {} } as Anthropic.Messages.ToolUseBlockParam
-				} else if (item.type === "tool_result") {
-					return {
-						...item,
-						content: Array.isArray(item.content)
-							? item.content.map((contentItem) =>
-									contentItem.type === "text"
-										? { type: "text", text: "(truncated due to context limits)" }
-										: contentItem.type === "image"
-										? { type: "text", text: "(image removed due to context limits)" }
-										: contentItem
-							  )
-							: "(truncated due to context limits)",
-					} as Anthropic.Messages.ToolResultBlockParam
-				}
-				return item
-			})
-		}
-		const newTokens = countMessageTokens(messageToEmpty)
-		totalMessageTokens -= originalTokens - newTokens
-		index++
-	}
-	return newMessages
+	let accurateTotalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
+	return accurateTotalMessageTokens <= availableTokens
+}
+
+/*
+We can't implement a dynamically updating sliding window as it would break prompt cache
+every time. To maintain the benefits of caching, we need to keep conversation history
+static. This operation should be performed as infrequently as possible. If a user reaches
+a 200k context, we can assume that the first half is likely irrelevant to their current task.
+Therefore, this function should only be called when absolutely necessary to fit within
+context limits, not as a continuous process.
+*/
+export function truncateHalfConversation(
+	messages: Anthropic.Messages.MessageParam[]
+): Anthropic.Messages.MessageParam[] {
+	// Anthropic expects messages to be in user-assistant order, and tool use messages must be followed by tool results. We need to maintain this structure while truncating.
+
+	// Keep the first Task message (likely the most important)
+	const truncatedMessages = [messages[0]]
+
+	// Remove half of user-assistant pairs
+	const messagesToRemove = Math.floor(messages.length / 4) * 2 // has to be even number
+	const summaryMessage: Anthropic.Messages.MessageParam = {
+		role: "assistant",
+		content: `(${messagesToRemove} messages were truncated to fit within context limits)`,
+	}
+	truncatedMessages.push(summaryMessage)
+
+	const remainingMessages = messages.slice(messagesToRemove)
+	truncatedMessages.push(...remainingMessages)
+
+	return truncatedMessages
 }
 
 function countMessageTokens(message: Anthropic.Messages.MessageParam): number {
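
Note: a worked check of the truncation arithmetic, assuming the truncateHalfConversation from the diff above. With 12 alternating messages, messagesToRemove = Math.floor(12 / 4) * 2 = 6; and because messages[0] is re-added up front before slice(6), the five messages at indices 1 through 5 are what actually get dropped, while the synthetic notice reports the computed count of 6:

    import { Anthropic } from "@anthropic-ai/sdk"
    import { truncateHalfConversation } from "./utils/context-management"

    const history: Anthropic.Messages.MessageParam[] = []
    for (let i = 0; i < 12; i++) {
        // alternating turns, starting with the user's task message
        history.push({ role: i % 2 === 0 ? "user" : "assistant", content: `message ${i}` })
    }

    const truncated = truncateHalfConversation(history)

    console.log(truncated.length) // 8: history[0], the notice, then history[6..11]
    console.log(truncated[1].content) // "(6 messages were truncated to fit within context limits)"
    console.log(truncated[2].role) // "user": index 6 is even, so user/assistant alternation is preserved

Since messagesToRemove is always even and the notice is an assistant turn following the kept user task message, the user-assistant ordering the API expects survives the cut.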