Improve sliding window algorithm to not break prompt cache so often

commit a160e8d67b
parent 98fdf898be
Author: Saoud Rizwan
Date:   2024-08-28 20:32:58 -04:00

2 changed files with 49 additions and 61 deletions
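Previously, slidingWindowContextManagement emptied out older messages whenever the conversation outgrew the context window, changing earlier parts of the prompt from request to request and defeating prompt caching. The new approach keeps the history static: a cheap isWithinContextWindow check runs before each request, and only on overflow does truncateHalfConversation drop the older half of the conversation in one shot. In outline (a minimal sketch of the new call order; prepareHistory and its signature are illustrative, not code from this commit):

import { Anthropic } from "@anthropic-ai/sdk"
import { isWithinContextWindow, truncateHalfConversation } from "./utils/context-management"

// Illustrative wrapper showing the order of operations introduced below.
function prepareHistory(
	contextWindow: number,
	systemPrompt: string,
	tools: Anthropic.Messages.Tool[],
	history: Anthropic.Messages.MessageParam[]
): Anthropic.Messages.MessageParam[] {
	// Cheap check first: while the prompt still fits, the history is sent
	// unchanged, so the cached prompt prefix stays valid across requests.
	if (isWithinContextWindow(contextWindow, systemPrompt, tools, history)) {
		return history
	}
	// Only on overflow do we pay the one-time cache invalidation.
	return truncateHalfConversation(history)
}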

@@ -25,7 +25,7 @@ import { HistoryItem } from "./shared/HistoryItem"
 import { combineApiRequests } from "./shared/combineApiRequests"
 import { combineCommandSequences } from "./shared/combineCommandSequences"
 import { findLastIndex } from "./utils"
-import { slidingWindowContextManagement } from "./utils/context-management"
+import { isWithinContextWindow, truncateHalfConversation } from "./utils/context-management"
 
 const SYSTEM_PROMPT =
 	() => `You are Claude Dev, a highly skilled software developer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
@@ -1253,13 +1253,21 @@ The following additional instructions are provided by the user. They should be f
 ${this.customInstructions.trim()}
 `
 		}
-		const adjustedMessages = slidingWindowContextManagement(
+		const isPromptWithinContextWindow = isWithinContextWindow(
 			this.api.getModel().info.contextWindow,
+			systemPrompt,
+			tools,
+			this.apiConversationHistory
+		)
+		if (!isPromptWithinContextWindow) {
+			const truncatedMessages = truncateHalfConversation(this.apiConversationHistory)
+			await this.overwriteApiConversationHistory(truncatedMessages)
+		}
+		const { message, userCredits } = await this.api.createMessage(
 			systemPrompt,
 			this.apiConversationHistory,
 			tools
 		)
-		const { message, userCredits } = await this.api.createMessage(systemPrompt, adjustedMessages, tools)
 		if (userCredits !== undefined) {
 			console.log("Updating credits", userCredits)
 			// TODO: update credits
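Note that the truncated history is persisted via overwriteApiConversationHistory rather than recomputed per request, so every subsequent request reuses an identical prefix: the prompt cache is rebuilt once after the cut instead of being broken on each call.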

@@ -2,74 +2,54 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { countTokens } from "@anthropic-ai/tokenizer"
 import { Buffer } from "buffer"
 import sizeOf from "image-size"
-import cloneDeep from "clone-deep"
 
-export function slidingWindowContextManagement(
+export function isWithinContextWindow(
 	contextWindow: number,
 	systemPrompt: string,
-	messages: Anthropic.Messages.MessageParam[],
-	tools: Anthropic.Messages.Tool[]
-): Anthropic.Messages.MessageParam[] {
+	tools: Anthropic.Messages.Tool[],
+	messages: Anthropic.Messages.MessageParam[]
+): boolean {
 	const adjustedContextWindow = contextWindow - 10_000 // Buffer to account for tokenizer differences
+	// counting tokens is expensive, so we first try to estimate before doing a more accurate calculation
+	const estimatedTotalMessageTokens = countTokens(systemPrompt + JSON.stringify(tools) + JSON.stringify(messages))
+	if (estimatedTotalMessageTokens <= adjustedContextWindow) {
+		return true
+	}
 	const systemPromptTokens = countTokens(systemPrompt)
 	const toolsTokens = countTokens(JSON.stringify(tools))
 	let availableTokens = adjustedContextWindow - systemPromptTokens - toolsTokens
-	let totalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
-
-	if (totalMessageTokens <= availableTokens) {
-		return messages
-	}
-
-	// If over limit, remove messages starting from the third message onwards (task and claude's step-by-step thought process are important to keep in context)
-	const newMessages = cloneDeep(messages) // since we're manipulating nested objects and arrays, need to deep clone to prevent mutating original history
-	let index = 2
-	while (totalMessageTokens > availableTokens && index < newMessages.length) {
-		const messageToEmpty = newMessages[index]
-		const originalTokens = countMessageTokens(messageToEmpty)
-		// Empty the content of the message (messages must be in a specific order so we can't just remove)
-		if (typeof messageToEmpty.content === "string") {
-			messageToEmpty.content = "(truncated due to context limits)"
-		} else if (Array.isArray(messageToEmpty.content)) {
-			messageToEmpty.content = messageToEmpty.content.map((item) => {
-				if (typeof item === "string") {
-					return {
-						type: "text",
-						text: "(truncated due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "text") {
-					return {
-						type: "text",
-						text: "(truncated due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "image") {
-					return {
-						type: "text",
-						text: "(image removed due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "tool_use") {
-					return { ...item, input: {} } as Anthropic.Messages.ToolUseBlockParam
-				} else if (item.type === "tool_result") {
-					return {
-						...item,
-						content: Array.isArray(item.content)
-							? item.content.map((contentItem) =>
-									contentItem.type === "text"
-										? { type: "text", text: "(truncated due to context limits)" }
-										: contentItem.type === "image"
-										? { type: "text", text: "(image removed due to context limits)" }
-										: contentItem
-							  )
-							: "(truncated due to context limits)",
-					} as Anthropic.Messages.ToolResultBlockParam
-				}
-				return item
-			})
-		}
-		const newTokens = countMessageTokens(messageToEmpty)
-		totalMessageTokens -= originalTokens - newTokens
-		index++
-	}
-	return newMessages
+	let accurateTotalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
+	return accurateTotalMessageTokens <= availableTokens
 }
 
+/*
+We can't implement a dynamically updating sliding window as it would break prompt cache
+every time. To maintain the benefits of caching, we need to keep conversation history
+static. This operation should be performed as infrequently as possible. If a user reaches
+a 200k context, we can assume that the first half is likely irrelevant to their current task.
+Therefore, this function should only be called when absolutely necessary to fit within
+context limits, not as a continuous process.
+*/
+export function truncateHalfConversation(
+	messages: Anthropic.Messages.MessageParam[]
+): Anthropic.Messages.MessageParam[] {
+	// Anthropic expects messages to be in user-assistant order, and tool use messages must be followed by tool results. We need to maintain this structure while truncating.
+
+	// Keep the first Task message (likely the most important)
+	const truncatedMessages = [messages[0]]
+
+	// Remove half of user-assistant pairs
+	const messagesToRemove = Math.floor(messages.length / 4) * 2 // has to be even number
+
+	const summaryMessage: Anthropic.Messages.MessageParam = {
+		role: "assistant",
+		content: `(${messagesToRemove} messages were truncated to fit within context limits)`,
+	}
+	truncatedMessages.push(summaryMessage)
+
+	const remainingMessages = messages.slice(messagesToRemove)
+	truncatedMessages.push(...remainingMessages)
+
+	return truncatedMessages
+}
 
 function countMessageTokens(message: Anthropic.Messages.MessageParam): number {
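For intuition, here is a hypothetical worked example of truncateHalfConversation (the history length and contents below are made up for illustration):

import { Anthropic } from "@anthropic-ai/sdk"
import { truncateHalfConversation } from "./utils/context-management"

// A made-up 10-message history: the task, then alternating assistant/user turns.
const history: Anthropic.Messages.MessageParam[] = Array.from({ length: 10 }, (_, i) => ({
	role: i % 2 === 0 ? ("user" as const) : ("assistant" as const),
	content: i === 0 ? "Task: refactor the parser" : `turn ${i}`,
}))

const truncated = truncateHalfConversation(history)
// messagesToRemove = Math.floor(10 / 4) * 2 = 4, so messages.slice(4) keeps
// turns 4-9; the task message is re-added up front, followed by the
// "(4 messages were truncated...)" placeholder: 8 messages in total.
// Because messagesToRemove is even, user/assistant alternation survives the
// cut, and the surviving tail is byte-identical to the original history.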