diff --git a/src/api/providers/__tests__/azure-ai.test.ts b/src/api/providers/__tests__/azure-ai.test.ts index 3741b55..9015720 100644 --- a/src/api/providers/__tests__/azure-ai.test.ts +++ b/src/api/providers/__tests__/azure-ai.test.ts @@ -3,18 +3,13 @@ import { ApiHandlerOptions } from "../../../shared/api" import { Readable } from "stream" import ModelClient from "@azure-rest/ai-inference" -// Mock the Azure AI client +// Mock isUnexpected separately since it's a named export +const mockIsUnexpected = jest.fn() jest.mock("@azure-rest/ai-inference", () => { - const mockClient = jest.fn().mockImplementation(() => ({ - path: jest.fn().mockReturnValue({ - post: jest.fn(), - }), - })) - return { __esModule: true, - default: mockClient, - isUnexpected: jest.fn(), + default: jest.fn(), + isUnexpected: () => mockIsUnexpected(), } }) @@ -27,6 +22,7 @@ describe("AzureAiHandler", () => { beforeEach(() => { jest.clearAllMocks() + mockIsUnexpected.mockReturnValue(false) }) test("constructs with required options", () => { @@ -47,8 +43,8 @@ describe("AzureAiHandler", () => { }) test("creates chat completion correctly", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockResponse = { + const mockPost = jest.fn().mockResolvedValue({ + status: 200, body: { choices: [ { @@ -58,94 +54,119 @@ describe("AzureAiHandler", () => { }, ], }, - } + }) - const mockClient = ModelClient as jest.MockedFunction - mockClient.mockReturnValue({ - path: jest.fn().mockReturnValue({ - post: jest.fn().mockResolvedValue(mockResponse), - }), - } as any) + const mockPath = jest.fn().mockReturnValue({ post: mockPost }) + ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath }) + const handler = new AzureAiHandler(mockOptions) const result = await handler.completePrompt("test prompt") + expect(result).toBe("test response") + expect(mockPath).toHaveBeenCalledWith("/chat/completions") + expect(mockPost).toHaveBeenCalledWith(expect.any(Object)) }) test("handles streaming responses correctly", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockStream = new Readable({ - read() { - this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n') - this.push( - 'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n', - ) - this.push("data: [DONE]\n\n") - this.push(null) - }, - }) + // Create a mock stream that properly emits SSE data + class MockReadable extends Readable { + private chunks: string[] + private index: number + constructor() { + super() + this.chunks = [ + 'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n', + 'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n', + "data: [DONE]\n\n", + ] + this.index = 0 + } + + override _read() { + if (this.index < this.chunks.length) { + this.push(Buffer.from(this.chunks[this.index++])) + } else { + this.push(null) + } + } + } + + const mockStream = new MockReadable() + + // Mock the client response with proper structure const mockResponse = { status: 200, + _response: { status: 200 }, body: mockStream, } - const mockClient = ModelClient as jest.MockedFunction - mockClient.mockReturnValue({ - path: jest.fn().mockReturnValue({ - post: jest.fn().mockReturnValue({ - asNodeStream: () => Promise.resolve(mockResponse), - }), - }), - } as any) + const mockPost = jest.fn().mockReturnValue({ + asNodeStream: jest.fn().mockResolvedValue(mockResponse), + }) + const mockPath = jest.fn().mockReturnValue({ post: mockPost }) + ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath }) + const handler = new AzureAiHandler(mockOptions) const messages = [] + + // Process the stream for await (const message of handler.createMessage("system prompt", [])) { messages.push(message) } + // Verify the results expect(messages).toEqual([ { type: "text", text: "Hello" }, { type: "text", text: " world" }, { type: "usage", inputTokens: 10, outputTokens: 2 }, ]) + + // Verify the client was called correctly + expect(mockPath).toHaveBeenCalledWith("/chat/completions") + expect(mockPost).toHaveBeenCalledWith({ + body: { + messages: [{ role: "system", content: "system prompt" }], + temperature: 0, + stream: true, + max_tokens: 4096, + response_format: { type: "text" }, + }, + headers: undefined, + }) }) test("handles rate limit errors", async () => { - const handler = new AzureAiHandler(mockOptions) const mockError = new Error("Rate limit exceeded") - Object.assign(mockError, { status: 429 }) + Object.defineProperty(mockError, "status", { value: 429 }) - const mockClient = ModelClient as jest.MockedFunction - mockClient.mockReturnValue({ - path: jest.fn().mockReturnValue({ - post: jest.fn().mockRejectedValue(mockError), - }), - } as any) + const mockPost = jest.fn().mockRejectedValue(mockError) + const mockPath = jest.fn().mockReturnValue({ post: mockPost }) + ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath }) + const handler = new AzureAiHandler(mockOptions) await expect(handler.completePrompt("test")).rejects.toThrow( "Azure AI rate limit exceeded. Please try again later.", ) }) test("handles content safety errors", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockError = { - status: 400, - body: { + const mockError = new Error("Content filter error") + Object.defineProperty(mockError, "status", { value: 400 }) + Object.defineProperty(mockError, "body", { + value: { error: { code: "ContentFilterError", message: "Content was flagged by content safety filters", }, }, - } + }) - const mockClient = ModelClient as jest.MockedFunction - mockClient.mockReturnValue({ - path: jest.fn().mockReturnValue({ - post: jest.fn().mockRejectedValue(mockError), - }), - } as any) + const mockPost = jest.fn().mockRejectedValue(mockError) + const mockPath = jest.fn().mockReturnValue({ post: mockPost }) + ;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath }) + const handler = new AzureAiHandler(mockOptions) await expect(handler.completePrompt("test")).rejects.toThrow( "Content was flagged by Azure AI content safety filters", ) @@ -158,7 +179,7 @@ describe("AzureAiHandler", () => { }) const model = handler.getModel() - expect(model.id).toBe("azure-gpt-35") + expect(model.id).toBe("gpt-35-turbo") expect(model.info).toBeDefined() }) @@ -179,6 +200,6 @@ describe("AzureAiHandler", () => { const model = handler.getModel() expect(model.id).toBe("custom-model") - expect(model.info).toBeDefined() + expect(model.info.contextWindow).toBe(16385) }) })