test: enhance Azure AI handler tests with improved mocking and error handling

This commit is contained in:
pacnpal
2025-02-02 13:27:14 -05:00
parent f6c5303925
commit 7874205951

View File

@@ -3,18 +3,13 @@ import { ApiHandlerOptions } from "../../../shared/api"
import { Readable } from "stream"
import ModelClient from "@azure-rest/ai-inference"
// Mock the Azure AI client
// Mock isUnexpected separately since it's a named export
const mockIsUnexpected = jest.fn()
jest.mock("@azure-rest/ai-inference", () => {
const mockClient = jest.fn().mockImplementation(() => ({
path: jest.fn().mockReturnValue({
post: jest.fn(),
}),
}))
return {
__esModule: true,
default: mockClient,
isUnexpected: jest.fn(),
default: jest.fn(),
isUnexpected: () => mockIsUnexpected(),
}
})
@@ -27,6 +22,7 @@ describe("AzureAiHandler", () => {
beforeEach(() => {
jest.clearAllMocks()
mockIsUnexpected.mockReturnValue(false)
})
test("constructs with required options", () => {
@@ -47,8 +43,8 @@ describe("AzureAiHandler", () => {
})
test("creates chat completion correctly", async () => {
const handler = new AzureAiHandler(mockOptions)
const mockResponse = {
const mockPost = jest.fn().mockResolvedValue({
status: 200,
body: {
choices: [
{
@@ -58,94 +54,119 @@ describe("AzureAiHandler", () => {
},
],
},
}
})
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
mockClient.mockReturnValue({
path: jest.fn().mockReturnValue({
post: jest.fn().mockResolvedValue(mockResponse),
}),
} as any)
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
const handler = new AzureAiHandler(mockOptions)
const result = await handler.completePrompt("test prompt")
expect(result).toBe("test response")
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
expect(mockPost).toHaveBeenCalledWith(expect.any(Object))
})
test("handles streaming responses correctly", async () => {
const handler = new AzureAiHandler(mockOptions)
const mockStream = new Readable({
read() {
this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n')
this.push(
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
)
this.push("data: [DONE]\n\n")
this.push(null)
},
})
// Create a mock stream that properly emits SSE data
class MockReadable extends Readable {
private chunks: string[]
private index: number
constructor() {
super()
this.chunks = [
'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
"data: [DONE]\n\n",
]
this.index = 0
}
override _read() {
if (this.index < this.chunks.length) {
this.push(Buffer.from(this.chunks[this.index++]))
} else {
this.push(null)
}
}
}
const mockStream = new MockReadable()
// Mock the client response with proper structure
const mockResponse = {
status: 200,
_response: { status: 200 },
body: mockStream,
}
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
mockClient.mockReturnValue({
path: jest.fn().mockReturnValue({
post: jest.fn().mockReturnValue({
asNodeStream: () => Promise.resolve(mockResponse),
}),
}),
} as any)
const mockPost = jest.fn().mockReturnValue({
asNodeStream: jest.fn().mockResolvedValue(mockResponse),
})
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
const handler = new AzureAiHandler(mockOptions)
const messages = []
// Process the stream
for await (const message of handler.createMessage("system prompt", [])) {
messages.push(message)
}
// Verify the results
expect(messages).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " world" },
{ type: "usage", inputTokens: 10, outputTokens: 2 },
])
// Verify the client was called correctly
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
expect(mockPost).toHaveBeenCalledWith({
body: {
messages: [{ role: "system", content: "system prompt" }],
temperature: 0,
stream: true,
max_tokens: 4096,
response_format: { type: "text" },
},
headers: undefined,
})
})
test("handles rate limit errors", async () => {
const handler = new AzureAiHandler(mockOptions)
const mockError = new Error("Rate limit exceeded")
Object.assign(mockError, { status: 429 })
Object.defineProperty(mockError, "status", { value: 429 })
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
mockClient.mockReturnValue({
path: jest.fn().mockReturnValue({
post: jest.fn().mockRejectedValue(mockError),
}),
} as any)
const mockPost = jest.fn().mockRejectedValue(mockError)
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
const handler = new AzureAiHandler(mockOptions)
await expect(handler.completePrompt("test")).rejects.toThrow(
"Azure AI rate limit exceeded. Please try again later.",
)
})
test("handles content safety errors", async () => {
const handler = new AzureAiHandler(mockOptions)
const mockError = {
status: 400,
body: {
const mockError = new Error("Content filter error")
Object.defineProperty(mockError, "status", { value: 400 })
Object.defineProperty(mockError, "body", {
value: {
error: {
code: "ContentFilterError",
message: "Content was flagged by content safety filters",
},
},
}
})
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
mockClient.mockReturnValue({
path: jest.fn().mockReturnValue({
post: jest.fn().mockRejectedValue(mockError),
}),
} as any)
const mockPost = jest.fn().mockRejectedValue(mockError)
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
const handler = new AzureAiHandler(mockOptions)
await expect(handler.completePrompt("test")).rejects.toThrow(
"Content was flagged by Azure AI content safety filters",
)
@@ -158,7 +179,7 @@ describe("AzureAiHandler", () => {
})
const model = handler.getModel()
expect(model.id).toBe("azure-gpt-35")
expect(model.id).toBe("gpt-35-turbo")
expect(model.info).toBeDefined()
})
@@ -179,6 +200,6 @@ describe("AzureAiHandler", () => {
const model = handler.getModel()
expect(model.id).toBe("custom-model")
expect(model.info).toBeDefined()
expect(model.info.contextWindow).toBe(16385)
})
})