mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
test: enhance Azure AI handler tests with improved mocking and error handling
This commit is contained in:
@@ -3,18 +3,13 @@ import { ApiHandlerOptions } from "../../../shared/api"
|
||||
import { Readable } from "stream"
|
||||
import ModelClient from "@azure-rest/ai-inference"
|
||||
|
||||
// Mock the Azure AI client
|
||||
// Mock isUnexpected separately since it's a named export
|
||||
const mockIsUnexpected = jest.fn()
|
||||
jest.mock("@azure-rest/ai-inference", () => {
|
||||
const mockClient = jest.fn().mockImplementation(() => ({
|
||||
path: jest.fn().mockReturnValue({
|
||||
post: jest.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
return {
|
||||
__esModule: true,
|
||||
default: mockClient,
|
||||
isUnexpected: jest.fn(),
|
||||
default: jest.fn(),
|
||||
isUnexpected: () => mockIsUnexpected(),
|
||||
}
|
||||
})
|
||||
|
||||
@@ -27,6 +22,7 @@ describe("AzureAiHandler", () => {
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
mockIsUnexpected.mockReturnValue(false)
|
||||
})
|
||||
|
||||
test("constructs with required options", () => {
|
||||
@@ -47,8 +43,8 @@ describe("AzureAiHandler", () => {
|
||||
})
|
||||
|
||||
test("creates chat completion correctly", async () => {
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
const mockResponse = {
|
||||
const mockPost = jest.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
@@ -58,94 +54,119 @@ describe("AzureAiHandler", () => {
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
||||
mockClient.mockReturnValue({
|
||||
path: jest.fn().mockReturnValue({
|
||||
post: jest.fn().mockResolvedValue(mockResponse),
|
||||
}),
|
||||
} as any)
|
||||
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
const result = await handler.completePrompt("test prompt")
|
||||
|
||||
expect(result).toBe("test response")
|
||||
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
|
||||
expect(mockPost).toHaveBeenCalledWith(expect.any(Object))
|
||||
})
|
||||
|
||||
test("handles streaming responses correctly", async () => {
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
const mockStream = new Readable({
|
||||
read() {
|
||||
this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n')
|
||||
this.push(
|
||||
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
|
||||
)
|
||||
this.push("data: [DONE]\n\n")
|
||||
this.push(null)
|
||||
},
|
||||
})
|
||||
// Create a mock stream that properly emits SSE data
|
||||
class MockReadable extends Readable {
|
||||
private chunks: string[]
|
||||
private index: number
|
||||
|
||||
constructor() {
|
||||
super()
|
||||
this.chunks = [
|
||||
'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
|
||||
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
this.index = 0
|
||||
}
|
||||
|
||||
override _read() {
|
||||
if (this.index < this.chunks.length) {
|
||||
this.push(Buffer.from(this.chunks[this.index++]))
|
||||
} else {
|
||||
this.push(null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const mockStream = new MockReadable()
|
||||
|
||||
// Mock the client response with proper structure
|
||||
const mockResponse = {
|
||||
status: 200,
|
||||
_response: { status: 200 },
|
||||
body: mockStream,
|
||||
}
|
||||
|
||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
||||
mockClient.mockReturnValue({
|
||||
path: jest.fn().mockReturnValue({
|
||||
post: jest.fn().mockReturnValue({
|
||||
asNodeStream: () => Promise.resolve(mockResponse),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
const mockPost = jest.fn().mockReturnValue({
|
||||
asNodeStream: jest.fn().mockResolvedValue(mockResponse),
|
||||
})
|
||||
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
const messages = []
|
||||
|
||||
// Process the stream
|
||||
for await (const message of handler.createMessage("system prompt", [])) {
|
||||
messages.push(message)
|
||||
}
|
||||
|
||||
// Verify the results
|
||||
expect(messages).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "text", text: " world" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 2 },
|
||||
])
|
||||
|
||||
// Verify the client was called correctly
|
||||
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [{ role: "system", content: "system prompt" }],
|
||||
temperature: 0,
|
||||
stream: true,
|
||||
max_tokens: 4096,
|
||||
response_format: { type: "text" },
|
||||
},
|
||||
headers: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
test("handles rate limit errors", async () => {
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
const mockError = new Error("Rate limit exceeded")
|
||||
Object.assign(mockError, { status: 429 })
|
||||
Object.defineProperty(mockError, "status", { value: 429 })
|
||||
|
||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
||||
mockClient.mockReturnValue({
|
||||
path: jest.fn().mockReturnValue({
|
||||
post: jest.fn().mockRejectedValue(mockError),
|
||||
}),
|
||||
} as any)
|
||||
const mockPost = jest.fn().mockRejectedValue(mockError)
|
||||
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
await expect(handler.completePrompt("test")).rejects.toThrow(
|
||||
"Azure AI rate limit exceeded. Please try again later.",
|
||||
)
|
||||
})
|
||||
|
||||
test("handles content safety errors", async () => {
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
const mockError = {
|
||||
status: 400,
|
||||
body: {
|
||||
const mockError = new Error("Content filter error")
|
||||
Object.defineProperty(mockError, "status", { value: 400 })
|
||||
Object.defineProperty(mockError, "body", {
|
||||
value: {
|
||||
error: {
|
||||
code: "ContentFilterError",
|
||||
message: "Content was flagged by content safety filters",
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
||||
mockClient.mockReturnValue({
|
||||
path: jest.fn().mockReturnValue({
|
||||
post: jest.fn().mockRejectedValue(mockError),
|
||||
}),
|
||||
} as any)
|
||||
const mockPost = jest.fn().mockRejectedValue(mockError)
|
||||
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||
|
||||
const handler = new AzureAiHandler(mockOptions)
|
||||
await expect(handler.completePrompt("test")).rejects.toThrow(
|
||||
"Content was flagged by Azure AI content safety filters",
|
||||
)
|
||||
@@ -158,7 +179,7 @@ describe("AzureAiHandler", () => {
|
||||
})
|
||||
const model = handler.getModel()
|
||||
|
||||
expect(model.id).toBe("azure-gpt-35")
|
||||
expect(model.id).toBe("gpt-35-turbo")
|
||||
expect(model.info).toBeDefined()
|
||||
})
|
||||
|
||||
@@ -179,6 +200,6 @@ describe("AzureAiHandler", () => {
|
||||
const model = handler.getModel()
|
||||
|
||||
expect(model.id).toBe("custom-model")
|
||||
expect(model.info).toBeDefined()
|
||||
expect(model.info.contextWindow).toBe(16385)
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user