Add test coverage

This commit is contained in:
Matt Rubens
2025-01-07 01:52:42 -05:00
parent 537514de44
commit 5fbfe9b775
19 changed files with 3106 additions and 493 deletions

View File

@@ -0,0 +1,230 @@
import { OpenAiNativeHandler } from "../openai-native"
import OpenAI from "openai"
import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
// Mock OpenAI
jest.mock("openai")
describe("OpenAiNativeHandler", () => {
let handler: OpenAiNativeHandler
let mockOptions: ApiHandlerOptions
let mockOpenAIClient: jest.Mocked<OpenAI>
let mockCreate: jest.Mock
beforeEach(() => {
// Reset mocks
jest.clearAllMocks()
// Setup mock options
mockOptions = {
openAiNativeApiKey: "test-api-key",
apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts
}
// Setup mock create function
mockCreate = jest.fn()
// Setup mock OpenAI client
mockOpenAIClient = {
chat: {
completions: {
create: mockCreate,
},
},
} as unknown as jest.Mocked<OpenAI>
// Mock OpenAI constructor
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIClient)
// Create handler instance
handler = new OpenAiNativeHandler(mockOptions)
})
describe("constructor", () => {
it("should initialize with provided options", () => {
expect(OpenAI).toHaveBeenCalledWith({
apiKey: mockOptions.openAiNativeApiKey,
})
})
})
describe("getModel", () => {
it("should return specified model when valid", () => {
const result = handler.getModel()
expect(result.id).toBe("gpt-4o") // Use the correct model ID
})
it("should return default model when model ID is invalid", () => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "invalid-model" as any,
})
const result = handler.getModel()
expect(result.id).toBe(openAiNativeDefaultModelId)
})
it("should return default model when model ID is not provided", () => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: undefined,
})
const result = handler.getModel()
expect(result.id).toBe(openAiNativeDefaultModelId)
})
})
describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
]
describe("o1 models", () => {
beforeEach(() => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "o1-preview",
})
})
it("should handle non-streaming response for o1 models", async () => {
const mockResponse = {
choices: [{ message: { content: "Hello there!" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
mockCreate.mockResolvedValueOnce(mockResponse)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "Hello there!" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(mockCreate).toHaveBeenCalledWith({
model: "o1-preview",
messages: [
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello" },
],
})
})
it("should handle missing content in response", async () => {
const mockResponse = {
choices: [{ message: { content: null } }],
usage: null,
}
mockCreate.mockResolvedValueOnce(mockResponse)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "" },
{ type: "usage", inputTokens: 0, outputTokens: 0 },
])
})
})
describe("streaming models", () => {
beforeEach(() => {
handler = new OpenAiNativeHandler({
...mockOptions,
apiModelId: "gpt-4o",
})
})
it("should handle streaming response", async () => {
const mockStream = [
{ choices: [{ delta: { content: "Hello" } }], usage: null },
{ choices: [{ delta: { content: " there" } }], usage: null },
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk
}
})()
)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " there" },
{ type: "text", text: "!" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(mockCreate).toHaveBeenCalledWith({
model: "gpt-4o",
temperature: 0,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello" },
],
stream: true,
stream_options: { include_usage: true },
})
})
it("should handle empty delta content", async () => {
const mockStream = [
{ choices: [{ delta: {} }], usage: null },
{ choices: [{ delta: { content: null } }], usage: null },
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
]
mockCreate.mockResolvedValueOnce(
(async function* () {
for (const chunk of mockStream) {
yield chunk
}
})()
)
const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const result of generator) {
results.push(result)
}
expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})
})
it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
const generator = handler.createMessage(systemPrompt, messages)
await expect(async () => {
for await (const _ of generator) {
// consume generator
}
}).rejects.toThrow("API Error")
})
})
})