Add typing for usageData and add some unit tests

This commit is contained in:
Vignesh Subbiah
2025-01-29 22:53:43 +05:30
parent 6ccb061d33
commit d6433591b2
2 changed files with 54 additions and 19 deletions

View File

@@ -1,6 +1,5 @@
import { UnboundHandler } from "../unbound" import { UnboundHandler } from "../unbound"
import { ApiHandlerOptions } from "../../../shared/api" import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI client // Mock OpenAI client
@@ -16,6 +15,7 @@ jest.mock("openai", () => {
create: (...args: any[]) => { create: (...args: any[]) => {
const stream = { const stream = {
[Symbol.asyncIterator]: async function* () { [Symbol.asyncIterator]: async function* () {
// First chunk with content
yield { yield {
choices: [ choices: [
{ {
@@ -24,13 +24,25 @@ jest.mock("openai", () => {
}, },
], ],
} }
// Second chunk with usage data
yield { yield {
choices: [ choices: [{ delta: {}, index: 0 }],
{ usage: {
delta: {}, prompt_tokens: 10,
index: 0, completion_tokens: 5,
total_tokens: 15,
},
}
// Third chunk with cache usage data
yield {
choices: [{ delta: {}, index: 0 }],
usage: {
prompt_tokens: 8,
completion_tokens: 4,
total_tokens: 12,
cache_creation_input_tokens: 3,
cache_read_input_tokens: 2,
}, },
],
} }
}, },
} }
@@ -95,19 +107,37 @@ describe("UnboundHandler", () => {
}, },
] ]
it("should handle streaming responses", async () => { it("should handle streaming responses with text and usage data", async () => {
const stream = handler.createMessage(systemPrompt, messages) const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = [] const chunks: Array<{ type: string } & Record<string, any>> = []
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk) chunks.push(chunk)
} }
expect(chunks.length).toBe(1) expect(chunks.length).toBe(3)
// Verify text chunk
expect(chunks[0]).toEqual({ expect(chunks[0]).toEqual({
type: "text", type: "text",
text: "Test response", text: "Test response",
}) })
// Verify regular usage data
expect(chunks[1]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
})
// Verify usage data with cache information
expect(chunks[2]).toEqual({
type: "usage",
inputTokens: 8,
outputTokens: 4,
cacheWriteTokens: 3,
cacheReadTokens: 2,
})
expect(mockCreate).toHaveBeenCalledWith( expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
model: "claude-3-5-sonnet-20241022", model: "claude-3-5-sonnet-20241022",

View File

@@ -3,7 +3,12 @@ import OpenAI from "openai"
import { ApiHandler, SingleCompletionHandler } from "../" import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api" import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format" import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
interface UnboundUsage extends OpenAI.CompletionUsage {
cache_creation_input_tokens?: number
cache_read_input_tokens?: number
}
export class UnboundHandler implements ApiHandler, SingleCompletionHandler { export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions private options: ApiHandlerOptions
@@ -96,7 +101,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
for await (const chunk of completion) { for await (const chunk of completion) {
const delta = chunk.choices[0]?.delta const delta = chunk.choices[0]?.delta
const usage = chunk.usage const usage = chunk.usage as UnboundUsage
if (delta?.content) { if (delta?.content) {
yield { yield {
@@ -106,18 +111,18 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
} }
if (usage) { if (usage) {
const usageData: any = { const usageData: ApiStreamUsageChunk = {
type: "usage", type: "usage",
inputTokens: usage?.prompt_tokens || 0, inputTokens: usage.prompt_tokens || 0,
outputTokens: usage?.completion_tokens || 0, outputTokens: usage.completion_tokens || 0,
} }
// Only add cache tokens if they exist // Only add cache tokens if they exist
if ((usage as any)?.cache_creation_input_tokens) { if (usage.cache_creation_input_tokens) {
usageData.cacheWriteTokens = (usage as any).cache_creation_input_tokens usageData.cacheWriteTokens = usage.cache_creation_input_tokens
} }
if ((usage as any)?.cache_read_input_tokens) { if (usage.cache_read_input_tokens) {
usageData.cacheReadTokens = (usage as any).cache_read_input_tokens usageData.cacheReadTokens = usage.cache_read_input_tokens
} }
yield usageData yield usageData