test: enhance Azure AI handler tests with improved mocking and error handling

This commit is contained in:
pacnpal
2025-02-02 13:27:14 -05:00
parent f6c5303925
commit 7874205951

View File

@@ -3,18 +3,13 @@ import { ApiHandlerOptions } from "../../../shared/api"
import { Readable } from "stream" import { Readable } from "stream"
import ModelClient from "@azure-rest/ai-inference" import ModelClient from "@azure-rest/ai-inference"
// Mock the Azure AI client // Mock isUnexpected separately since it's a named export
const mockIsUnexpected = jest.fn()
jest.mock("@azure-rest/ai-inference", () => { jest.mock("@azure-rest/ai-inference", () => {
const mockClient = jest.fn().mockImplementation(() => ({
path: jest.fn().mockReturnValue({
post: jest.fn(),
}),
}))
return { return {
__esModule: true, __esModule: true,
default: mockClient, default: jest.fn(),
isUnexpected: jest.fn(), isUnexpected: () => mockIsUnexpected(),
} }
}) })
@@ -27,6 +22,7 @@ describe("AzureAiHandler", () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() jest.clearAllMocks()
mockIsUnexpected.mockReturnValue(false)
}) })
test("constructs with required options", () => { test("constructs with required options", () => {
@@ -47,8 +43,8 @@ describe("AzureAiHandler", () => {
}) })
test("creates chat completion correctly", async () => { test("creates chat completion correctly", async () => {
const handler = new AzureAiHandler(mockOptions) const mockPost = jest.fn().mockResolvedValue({
const mockResponse = { status: 200,
body: { body: {
choices: [ choices: [
{ {
@@ -58,94 +54,119 @@ describe("AzureAiHandler", () => {
}, },
], ],
}, },
} })
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient> const mockPath = jest.fn().mockReturnValue({ post: mockPost })
mockClient.mockReturnValue({ ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
path: jest.fn().mockReturnValue({
post: jest.fn().mockResolvedValue(mockResponse),
}),
} as any)
const handler = new AzureAiHandler(mockOptions)
const result = await handler.completePrompt("test prompt") const result = await handler.completePrompt("test prompt")
expect(result).toBe("test response") expect(result).toBe("test response")
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
expect(mockPost).toHaveBeenCalledWith(expect.any(Object))
}) })
test("handles streaming responses correctly", async () => { test("handles streaming responses correctly", async () => {
const handler = new AzureAiHandler(mockOptions) // Create a mock stream that properly emits SSE data
const mockStream = new Readable({ class MockReadable extends Readable {
read() { private chunks: string[]
this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n') private index: number
this.push(
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
)
this.push("data: [DONE]\n\n")
this.push(null)
},
})
constructor() {
super()
this.chunks = [
'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
"data: [DONE]\n\n",
]
this.index = 0
}
override _read() {
if (this.index < this.chunks.length) {
this.push(Buffer.from(this.chunks[this.index++]))
} else {
this.push(null)
}
}
}
const mockStream = new MockReadable()
// Mock the client response with proper structure
const mockResponse = { const mockResponse = {
status: 200, status: 200,
_response: { status: 200 },
body: mockStream, body: mockStream,
} }
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient> const mockPost = jest.fn().mockReturnValue({
mockClient.mockReturnValue({ asNodeStream: jest.fn().mockResolvedValue(mockResponse),
path: jest.fn().mockReturnValue({ })
post: jest.fn().mockReturnValue({ const mockPath = jest.fn().mockReturnValue({ post: mockPost })
asNodeStream: () => Promise.resolve(mockResponse), ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
}),
}),
} as any)
const handler = new AzureAiHandler(mockOptions)
const messages = [] const messages = []
// Process the stream
for await (const message of handler.createMessage("system prompt", [])) { for await (const message of handler.createMessage("system prompt", [])) {
messages.push(message) messages.push(message)
} }
// Verify the results
expect(messages).toEqual([ expect(messages).toEqual([
{ type: "text", text: "Hello" }, { type: "text", text: "Hello" },
{ type: "text", text: " world" }, { type: "text", text: " world" },
{ type: "usage", inputTokens: 10, outputTokens: 2 }, { type: "usage", inputTokens: 10, outputTokens: 2 },
]) ])
// Verify the client was called correctly
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
expect(mockPost).toHaveBeenCalledWith({
body: {
messages: [{ role: "system", content: "system prompt" }],
temperature: 0,
stream: true,
max_tokens: 4096,
response_format: { type: "text" },
},
headers: undefined,
})
}) })
test("handles rate limit errors", async () => { test("handles rate limit errors", async () => {
const handler = new AzureAiHandler(mockOptions)
const mockError = new Error("Rate limit exceeded") const mockError = new Error("Rate limit exceeded")
Object.assign(mockError, { status: 429 }) Object.defineProperty(mockError, "status", { value: 429 })
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient> const mockPost = jest.fn().mockRejectedValue(mockError)
mockClient.mockReturnValue({ const mockPath = jest.fn().mockReturnValue({ post: mockPost })
path: jest.fn().mockReturnValue({ ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
post: jest.fn().mockRejectedValue(mockError),
}),
} as any)
const handler = new AzureAiHandler(mockOptions)
await expect(handler.completePrompt("test")).rejects.toThrow( await expect(handler.completePrompt("test")).rejects.toThrow(
"Azure AI rate limit exceeded. Please try again later.", "Azure AI rate limit exceeded. Please try again later.",
) )
}) })
test("handles content safety errors", async () => { test("handles content safety errors", async () => {
const handler = new AzureAiHandler(mockOptions) const mockError = new Error("Content filter error")
const mockError = { Object.defineProperty(mockError, "status", { value: 400 })
status: 400, Object.defineProperty(mockError, "body", {
body: { value: {
error: { error: {
code: "ContentFilterError", code: "ContentFilterError",
message: "Content was flagged by content safety filters", message: "Content was flagged by content safety filters",
}, },
}, },
} })
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient> const mockPost = jest.fn().mockRejectedValue(mockError)
mockClient.mockReturnValue({ const mockPath = jest.fn().mockReturnValue({ post: mockPost })
path: jest.fn().mockReturnValue({ ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
post: jest.fn().mockRejectedValue(mockError),
}),
} as any)
const handler = new AzureAiHandler(mockOptions)
await expect(handler.completePrompt("test")).rejects.toThrow( await expect(handler.completePrompt("test")).rejects.toThrow(
"Content was flagged by Azure AI content safety filters", "Content was flagged by Azure AI content safety filters",
) )
@@ -158,7 +179,7 @@ describe("AzureAiHandler", () => {
}) })
const model = handler.getModel() const model = handler.getModel()
expect(model.id).toBe("azure-gpt-35") expect(model.id).toBe("gpt-35-turbo")
expect(model.info).toBeDefined() expect(model.info).toBeDefined()
}) })
@@ -179,6 +200,6 @@ describe("AzureAiHandler", () => {
const model = handler.getModel() const model = handler.getModel()
expect(model.id).toBe("custom-model") expect(model.id).toBe("custom-model")
expect(model.info).toBeDefined() expect(model.info.contextWindow).toBe(16385)
}) })
}) })