mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
test: enhance Azure AI handler tests with improved mocking and error handling
This commit is contained in:
@@ -3,18 +3,13 @@ import { ApiHandlerOptions } from "../../../shared/api"
|
|||||||
import { Readable } from "stream"
|
import { Readable } from "stream"
|
||||||
import ModelClient from "@azure-rest/ai-inference"
|
import ModelClient from "@azure-rest/ai-inference"
|
||||||
|
|
||||||
// Mock the Azure AI client
|
// Mock isUnexpected separately since it's a named export
|
||||||
|
const mockIsUnexpected = jest.fn()
|
||||||
jest.mock("@azure-rest/ai-inference", () => {
|
jest.mock("@azure-rest/ai-inference", () => {
|
||||||
const mockClient = jest.fn().mockImplementation(() => ({
|
|
||||||
path: jest.fn().mockReturnValue({
|
|
||||||
post: jest.fn(),
|
|
||||||
}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
__esModule: true,
|
__esModule: true,
|
||||||
default: mockClient,
|
default: jest.fn(),
|
||||||
isUnexpected: jest.fn(),
|
isUnexpected: () => mockIsUnexpected(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -27,6 +22,7 @@ describe("AzureAiHandler", () => {
|
|||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
|
mockIsUnexpected.mockReturnValue(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
test("constructs with required options", () => {
|
test("constructs with required options", () => {
|
||||||
@@ -47,8 +43,8 @@ describe("AzureAiHandler", () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
test("creates chat completion correctly", async () => {
|
test("creates chat completion correctly", async () => {
|
||||||
const handler = new AzureAiHandler(mockOptions)
|
const mockPost = jest.fn().mockResolvedValue({
|
||||||
const mockResponse = {
|
status: 200,
|
||||||
body: {
|
body: {
|
||||||
choices: [
|
choices: [
|
||||||
{
|
{
|
||||||
@@ -58,94 +54,119 @@ describe("AzureAiHandler", () => {
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|
||||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||||
mockClient.mockReturnValue({
|
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||||
path: jest.fn().mockReturnValue({
|
|
||||||
post: jest.fn().mockResolvedValue(mockResponse),
|
|
||||||
}),
|
|
||||||
} as any)
|
|
||||||
|
|
||||||
|
const handler = new AzureAiHandler(mockOptions)
|
||||||
const result = await handler.completePrompt("test prompt")
|
const result = await handler.completePrompt("test prompt")
|
||||||
|
|
||||||
expect(result).toBe("test response")
|
expect(result).toBe("test response")
|
||||||
|
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
|
||||||
|
expect(mockPost).toHaveBeenCalledWith(expect.any(Object))
|
||||||
})
|
})
|
||||||
|
|
||||||
test("handles streaming responses correctly", async () => {
|
test("handles streaming responses correctly", async () => {
|
||||||
const handler = new AzureAiHandler(mockOptions)
|
// Create a mock stream that properly emits SSE data
|
||||||
const mockStream = new Readable({
|
class MockReadable extends Readable {
|
||||||
read() {
|
private chunks: string[]
|
||||||
this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n')
|
private index: number
|
||||||
this.push(
|
|
||||||
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
|
|
||||||
)
|
|
||||||
this.push("data: [DONE]\n\n")
|
|
||||||
this.push(null)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super()
|
||||||
|
this.chunks = [
|
||||||
|
'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
|
||||||
|
'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
this.index = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
override _read() {
|
||||||
|
if (this.index < this.chunks.length) {
|
||||||
|
this.push(Buffer.from(this.chunks[this.index++]))
|
||||||
|
} else {
|
||||||
|
this.push(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockStream = new MockReadable()
|
||||||
|
|
||||||
|
// Mock the client response with proper structure
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
status: 200,
|
status: 200,
|
||||||
|
_response: { status: 200 },
|
||||||
body: mockStream,
|
body: mockStream,
|
||||||
}
|
}
|
||||||
|
|
||||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
const mockPost = jest.fn().mockReturnValue({
|
||||||
mockClient.mockReturnValue({
|
asNodeStream: jest.fn().mockResolvedValue(mockResponse),
|
||||||
path: jest.fn().mockReturnValue({
|
})
|
||||||
post: jest.fn().mockReturnValue({
|
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||||
asNodeStream: () => Promise.resolve(mockResponse),
|
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||||
}),
|
|
||||||
}),
|
|
||||||
} as any)
|
|
||||||
|
|
||||||
|
const handler = new AzureAiHandler(mockOptions)
|
||||||
const messages = []
|
const messages = []
|
||||||
|
|
||||||
|
// Process the stream
|
||||||
for await (const message of handler.createMessage("system prompt", [])) {
|
for await (const message of handler.createMessage("system prompt", [])) {
|
||||||
messages.push(message)
|
messages.push(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify the results
|
||||||
expect(messages).toEqual([
|
expect(messages).toEqual([
|
||||||
{ type: "text", text: "Hello" },
|
{ type: "text", text: "Hello" },
|
||||||
{ type: "text", text: " world" },
|
{ type: "text", text: " world" },
|
||||||
{ type: "usage", inputTokens: 10, outputTokens: 2 },
|
{ type: "usage", inputTokens: 10, outputTokens: 2 },
|
||||||
])
|
])
|
||||||
|
|
||||||
|
// Verify the client was called correctly
|
||||||
|
expect(mockPath).toHaveBeenCalledWith("/chat/completions")
|
||||||
|
expect(mockPost).toHaveBeenCalledWith({
|
||||||
|
body: {
|
||||||
|
messages: [{ role: "system", content: "system prompt" }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 4096,
|
||||||
|
response_format: { type: "text" },
|
||||||
|
},
|
||||||
|
headers: undefined,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
test("handles rate limit errors", async () => {
|
test("handles rate limit errors", async () => {
|
||||||
const handler = new AzureAiHandler(mockOptions)
|
|
||||||
const mockError = new Error("Rate limit exceeded")
|
const mockError = new Error("Rate limit exceeded")
|
||||||
Object.assign(mockError, { status: 429 })
|
Object.defineProperty(mockError, "status", { value: 429 })
|
||||||
|
|
||||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
const mockPost = jest.fn().mockRejectedValue(mockError)
|
||||||
mockClient.mockReturnValue({
|
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||||
path: jest.fn().mockReturnValue({
|
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||||
post: jest.fn().mockRejectedValue(mockError),
|
|
||||||
}),
|
|
||||||
} as any)
|
|
||||||
|
|
||||||
|
const handler = new AzureAiHandler(mockOptions)
|
||||||
await expect(handler.completePrompt("test")).rejects.toThrow(
|
await expect(handler.completePrompt("test")).rejects.toThrow(
|
||||||
"Azure AI rate limit exceeded. Please try again later.",
|
"Azure AI rate limit exceeded. Please try again later.",
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
test("handles content safety errors", async () => {
|
test("handles content safety errors", async () => {
|
||||||
const handler = new AzureAiHandler(mockOptions)
|
const mockError = new Error("Content filter error")
|
||||||
const mockError = {
|
Object.defineProperty(mockError, "status", { value: 400 })
|
||||||
status: 400,
|
Object.defineProperty(mockError, "body", {
|
||||||
body: {
|
value: {
|
||||||
error: {
|
error: {
|
||||||
code: "ContentFilterError",
|
code: "ContentFilterError",
|
||||||
message: "Content was flagged by content safety filters",
|
message: "Content was flagged by content safety filters",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|
||||||
const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
|
const mockPost = jest.fn().mockRejectedValue(mockError)
|
||||||
mockClient.mockReturnValue({
|
const mockPath = jest.fn().mockReturnValue({ post: mockPost })
|
||||||
path: jest.fn().mockReturnValue({
|
;(ModelClient as jest.Mock).mockReturnValue({ path: mockPath })
|
||||||
post: jest.fn().mockRejectedValue(mockError),
|
|
||||||
}),
|
|
||||||
} as any)
|
|
||||||
|
|
||||||
|
const handler = new AzureAiHandler(mockOptions)
|
||||||
await expect(handler.completePrompt("test")).rejects.toThrow(
|
await expect(handler.completePrompt("test")).rejects.toThrow(
|
||||||
"Content was flagged by Azure AI content safety filters",
|
"Content was flagged by Azure AI content safety filters",
|
||||||
)
|
)
|
||||||
@@ -158,7 +179,7 @@ describe("AzureAiHandler", () => {
|
|||||||
})
|
})
|
||||||
const model = handler.getModel()
|
const model = handler.getModel()
|
||||||
|
|
||||||
expect(model.id).toBe("azure-gpt-35")
|
expect(model.id).toBe("gpt-35-turbo")
|
||||||
expect(model.info).toBeDefined()
|
expect(model.info).toBeDefined()
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -179,6 +200,6 @@ describe("AzureAiHandler", () => {
|
|||||||
const model = handler.getModel()
|
const model = handler.getModel()
|
||||||
|
|
||||||
expect(model.id).toBe("custom-model")
|
expect(model.id).toBe("custom-model")
|
||||||
expect(model.info).toBeDefined()
|
expect(model.info.contextWindow).toBe(16385)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user