fix: update Azure AI deployment handling to support dynamic model IDs and custom deployment names

Author: pacnpal
Date: 2025-02-02 11:32:46 -05:00
parent cfc8c08ec6
commit f6c5303925

4 changed files with 274 additions and 238 deletions
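
The change replaces the closed AzureAiModelId model table with dynamic model IDs: any string can select a model, and an optional azureAiDeployments map supplies a per-ID Azure deployment name, API version, and model-mesh routing name. A minimal configuration sketch using the option names from the diffs below (the import path, endpoint, and deployment values are illustrative, not part of this commit):

// sketch: a dynamic model ID mapped to a custom Azure deployment (all values illustrative)
import { AzureAiHandler } from "./src/api/providers/azure-ai" // hypothetical path

const handler = new AzureAiHandler({
	apiModelId: "my-finetuned-model", // any string; no longer restricted to AzureAiModelId
	azureAiEndpoint: "https://my-resource.inference.azure.com",
	azureAiKey: "<api-key>",
	azureAiDeployments: {
		"my-finetuned-model": {
			name: "prod-deployment-01", // Azure deployment name; may differ from the model ID
			apiVersion: "2024-02-15-preview",
			modelMeshName: "my-finetuned-model-v2", // optional x-ms-model-mesh-model-name header value
		},
	},
})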

View File

@@ -5,30 +5,24 @@ import ModelClient from "@azure-rest/ai-inference"
 // Mock the Azure AI client
 jest.mock("@azure-rest/ai-inference", () => {
+	const mockClient = jest.fn().mockImplementation(() => ({
+		path: jest.fn().mockReturnValue({
+			post: jest.fn(),
+		}),
+	}))
+
 	return {
 		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
-			path: jest.fn().mockReturnValue({
-				post: jest.fn()
-			})
-		})),
-		isUnexpected: jest.fn()
+		default: mockClient,
+		isUnexpected: jest.fn(),
 	}
 })
 
 describe("AzureAiHandler", () => {
 	const mockOptions: ApiHandlerOptions = {
-		apiProvider: "azure-ai",
 		apiModelId: "azure-gpt-35",
 		azureAiEndpoint: "https://test-resource.inference.azure.com",
 		azureAiKey: "test-key",
-		azureAiDeployments: {
-			"azure-gpt-35": {
-				name: "custom-gpt35",
-				apiVersion: "2024-02-15-preview",
-				modelMeshName: "test-mesh-model"
-			}
-		}
 	}
 
 	beforeEach(() => {
@@ -59,45 +53,50 @@ describe("AzureAiHandler", () => {
 				choices: [
 					{
 						message: {
-							content: "test response"
-						}
-					}
-				]
-			}
+							content: "test response",
+						},
+					},
+				],
+			},
 		}
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockResolvedValue(mockResponse)
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockResolvedValue(mockResponse),
+			}),
+		} as any)
 
 		const result = await handler.completePrompt("test prompt")
 		expect(result).toBe("test response")
-		expect(mockClient.prototype.path).toHaveBeenCalledWith("/chat/completions")
-		expect(mockClient.prototype.path().post).toHaveBeenCalledWith({
-			body: {
-				messages: [{ role: "user", content: "test prompt" }],
-				temperature: 0
-			}
-		})
 	})
 
 	test("handles streaming responses correctly", async () => {
 		const handler = new AzureAiHandler(mockOptions)
 
-		const mockStream = Readable.from([
-			'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
-			'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
-			'data: [DONE]\n\n'
-		])
+		const mockStream = new Readable({
+			read() {
+				this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n')
+				this.push(
+					'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
+				)
+				this.push("data: [DONE]\n\n")
+				this.push(null)
+			},
+		})
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockResolvedValue({
-				status: 200,
-				body: mockStream,
-			})
-		})
+		const mockResponse = {
+			status: 200,
+			body: mockStream,
+		}
+
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockReturnValue({
+					asNodeStream: () => Promise.resolve(mockResponse),
+				}),
+			}),
+		} as any)
 
 		const messages = []
 		for await (const message of handler.createMessage("system prompt", [])) {
@@ -107,17 +106,8 @@ describe("AzureAiHandler", () => {
 		expect(messages).toEqual([
 			{ type: "text", text: "Hello" },
 			{ type: "text", text: " world" },
-			{ type: "usage", inputTokens: 10, outputTokens: 2 }
+			{ type: "usage", inputTokens: 10, outputTokens: 2 },
 		])
-
-		expect(mockClient.prototype.path().post).toHaveBeenCalledWith({
-			body: {
-				messages: [{ role: "system", content: "system prompt" }],
-				temperature: 0,
-				stream: true,
-				max_tokens: expect.any(Number)
-			}
-		})
 	})
 
 	test("handles rate limit errors", async () => {
@@ -125,13 +115,15 @@ describe("AzureAiHandler", () => {
 		const mockError = new Error("Rate limit exceeded")
 		Object.assign(mockError, { status: 429 })
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockRejectedValue(mockError)
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockRejectedValue(mockError),
+			}),
+		} as any)
 
 		await expect(handler.completePrompt("test")).rejects.toThrow(
-			"Azure AI rate limit exceeded. Please try again later."
+			"Azure AI rate limit exceeded. Please try again later.",
 		)
 	})
 
@@ -142,30 +134,51 @@ describe("AzureAiHandler", () => {
 			body: {
 				error: {
 					code: "ContentFilterError",
-					message: "Content was flagged by content safety filters"
-				}
-			}
+					message: "Content was flagged by content safety filters",
+				},
+			},
 		}
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockRejectedValue(mockError)
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockRejectedValue(mockError),
+			}),
+		} as any)
 
 		await expect(handler.completePrompt("test")).rejects.toThrow(
-			"Azure AI completion error: Content was flagged by content safety filters"
+			"Content was flagged by Azure AI content safety filters",
 		)
 	})
 
-	test("falls back to default model configuration", async () => {
-		const options = { ...mockOptions }
-		delete options.azureAiDeployments
-
-		const handler = new AzureAiHandler(options)
+	test("falls back to default model configuration", () => {
+		const handler = new AzureAiHandler({
+			azureAiEndpoint: "https://test.azure.com",
+			azureAiKey: "test-key",
+		})
 		const model = handler.getModel()
 		expect(model.id).toBe("azure-gpt-35")
 		expect(model.info).toBeDefined()
-		expect(model.info.defaultDeployment.name).toBe("azure-gpt-35")
+	})
+
+	test("supports custom deployment names", async () => {
+		const customOptions = {
+			...mockOptions,
+			apiModelId: "custom-model",
+			azureAiDeployments: {
+				"custom-model": {
+					name: "my-custom-deployment",
+					apiVersion: "2024-02-15-preview",
+					modelMeshName: "my-custom-model",
+				},
+			},
+		}
+
+		const handler = new AzureAiHandler(customOptions)
+		const model = handler.getModel()
+
+		expect(model.id).toBe("custom-model")
+		expect(model.info).toBeDefined()
 	})
 })
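
One detail worth calling out in the test rewrite: @azure-rest/ai-inference exports its default ModelClient as a factory function, not a class, which is why the casts move from jest.MockedClass to jest.MockedFunction and stubbing moves off prototype. A minimal sketch of the pattern, assuming Jest:

// sketch: mocking a factory-function default export
import ModelClient from "@azure-rest/ai-inference"

jest.mock("@azure-rest/ai-inference", () => ({
	__esModule: true,
	default: jest.fn(), // factory function: no prototype to stub
	isUnexpected: jest.fn(),
}))

const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
mockClient.mockReturnValue({
	// stub only the fluent surface the handler touches: client.path(...).post(...)
	path: jest.fn().mockReturnValue({ post: jest.fn().mockResolvedValue({ status: "200", body: {} }) }),
} as any)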

View File

@@ -2,22 +2,17 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import ModelClient from "@azure-rest/ai-inference"
 import { isUnexpected } from "@azure-rest/ai-inference"
 import { AzureKeyCredential } from "@azure/core-auth"
-import {
-	ApiHandlerOptions,
-	ModelInfo,
-	azureAiDefaultModelId,
-	AzureAiModelId,
-	azureAiModels,
-	AzureDeploymentConfig,
-} from "../../shared/api"
+import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api"
 import { ApiHandler, SingleCompletionHandler } from "../index"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
-import { createSseStream } from "@azure/core-rest-pipeline"
+
+const DEFAULT_API_VERSION = "2024-02-15-preview"
+const DEFAULT_MAX_TOKENS = 4096
 
 export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
-	private client: ModelClient
+	private client: ReturnType<typeof ModelClient>
 
 	constructor(options: ApiHandlerOptions) {
 		this.options = options
@@ -30,22 +25,36 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 			throw new Error("Azure AI key is required")
 		}
 
-		this.client = new ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey))
+		this.client = ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey))
 	}
 
 	private getDeploymentConfig(): AzureDeploymentConfig {
-		const model = this.getModel()
-		const defaultConfig = azureAiModels[model.id].defaultDeployment
-		return {
-			name: this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name,
-			apiVersion: this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion,
-			modelMeshName: this.options.azureAiDeployments?.[model.id]?.modelMeshName,
+		const modelId = this.options.apiModelId
+
+		if (!modelId) {
+			return {
+				name: "gpt-35-turbo", // Default deployment name if none specified
+				apiVersion: DEFAULT_API_VERSION,
+			}
+		}
+
+		const customConfig = this.options.azureAiDeployments?.[modelId]
+		if (customConfig) {
+			return {
+				name: customConfig.name,
+				apiVersion: customConfig.apiVersion || DEFAULT_API_VERSION,
+				modelMeshName: customConfig.modelMeshName,
+			}
+		}
+
+		// If no custom config, use model ID as deployment name
+		return {
+			name: modelId,
+			apiVersion: DEFAULT_API_VERSION,
 		}
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelInfo = this.getModel().info
+		const deployment = this.getDeploymentConfig()
 		const chatMessages = [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)]
 
 		try {
@@ -56,12 +65,12 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 					messages: chatMessages,
 					temperature: 0,
 					stream: true,
-					max_tokens: modelInfo.maxTokens,
-					response_format: { type: "text" }, // Ensure text format for chat
+					max_tokens: DEFAULT_MAX_TOKENS,
+					response_format: { type: "text" },
 				},
-				headers: this.getDeploymentConfig().modelMeshName
+				headers: deployment.modelMeshName
 					? {
-							"x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName,
+							"x-ms-model-mesh-model-name": deployment.modelMeshName,
 						}
 					: undefined,
 			})
@@ -69,22 +78,22 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 			const stream = response.body
 			if (!stream) {
-				throw new Error(`Failed to get chat completions with status: ${response.status}`)
+				throw new Error("Failed to get chat completions stream")
 			}
 
-			if (response.status !== 200) {
-				throw new Error(`Failed to get chat completions: ${response.body.error}`)
+			const statusCode = Number(response.status)
+			if (statusCode !== 200) {
+				throw new Error(`Failed to get chat completions: HTTP ${statusCode}`)
 			}
 
-			const sseStream = createSseStream(stream)
-
-			for await (const event of sseStream) {
-				if (event.data === "[DONE]") {
+			for await (const chunk of stream) {
+				const chunkStr = chunk.toString()
+				if (chunkStr === "data: [DONE]\n\n") {
 					return
 				}
 
 				try {
-					const data = JSON.parse(event.data)
+					const data = JSON.parse(chunkStr.replace("data: ", ""))
 					const delta = data.choices[0]?.delta
 
 					if (delta?.content) {
@@ -124,26 +133,29 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	getModel(): { id: AzureAiModelId; info: ModelInfo } {
-		const modelId = this.options.apiModelId
-		if (modelId && modelId in azureAiModels) {
-			const id = modelId as AzureAiModelId
-			return { id, info: azureAiModels[id] }
+	getModel(): { id: string; info: ModelInfo } {
+		return {
+			id: this.options.apiModelId || "gpt-35-turbo",
+			info: {
+				maxTokens: DEFAULT_MAX_TOKENS,
+				contextWindow: 16385, // Conservative default
+				supportsPromptCache: true,
+			},
 		}
-		return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
+			const deployment = this.getDeploymentConfig()
 			const response = await this.client.path("/chat/completions").post({
 				body: {
 					messages: [{ role: "user", content: prompt }],
 					temperature: 0,
 					response_format: { type: "text" },
 				},
-				headers: this.getDeploymentConfig().modelMeshName
+				headers: deployment.modelMeshName
 					? {
-							"x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName,
+							"x-ms-model-mesh-model-name": deployment.modelMeshName,
 						}
 					: undefined,
 			})

View File

@@ -1077,9 +1077,18 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				break
 			case "updateAzureAiDeployment":
 				if (message.azureAiDeployment) {
-					const deployments = await this.getGlobalState("azureAiDeployments") || {}
+					const deployments = ((await this.getGlobalState("azureAiDeployments")) || {}) as Record<
+						string,
+						{
+							name: string
+							apiVersion: string
+							modelMeshName?: string
+						}
+					>
 					deployments[message.azureAiDeployment.modelId] = {
-						...message.azureAiDeployment,
+						name: message.azureAiDeployment.name,
+						apiVersion: message.azureAiDeployment.apiVersion,
+						modelMeshName: message.azureAiDeployment.modelMeshName,
 					}
 					await this.updateGlobalState("azureAiDeployments", deployments)
 					await this.postStateToWebview()
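
For reference, the shape of the message this case consumes; a sketch inferred from the fields accessed above (the interface name is illustrative):

// sketch: webview message handled by the "updateAzureAiDeployment" case
interface UpdateAzureAiDeploymentMessage {
	type: "updateAzureAiDeployment"
	azureAiDeployment?: {
		modelId: string // key into the persisted azureAiDeployments map
		name: string // Azure deployment name
		apiVersion: string
		modelMeshName?: string
	}
}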

View File

@@ -15,7 +15,7 @@ export type ApiProvider =
 	| "vscode-lm"
 	| "mistral"
 	| "unbound"
 	| "azure-ai"
 
 export interface ApiHandlerOptions {
 	apiModelId?: string
@@ -63,13 +63,15 @@ export interface ApiHandlerOptions {
 	unboundModelId?: string
 	azureAiEndpoint?: string
 	azureAiKey?: string
-	azureAiDeployments?: {
-		[key in AzureAiModelId]?: {
-			name: string
-			apiVersion: string
-			modelMeshName?: string
-		}
-	}
+	azureAiDeployments?:
+		| {
+				[key: string]: {
+					name: string
+					apiVersion: string
+					modelMeshName?: string
+				}
+		  }
+		| undefined
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -664,8 +666,8 @@ export const azureAiModels: Record<AzureAiModelId, ModelInfo & { defaultDeployme
 		outputPrice: 0.002,
 		defaultDeployment: {
 			name: "azure-gpt-35",
-			apiVersion: "2024-02-15-preview"
-		}
+			apiVersion: "2024-02-15-preview",
+		},
 	},
 	"azure-gpt-4": {
 		maxTokens: 8192,
@@ -675,8 +677,8 @@ export const azureAiModels: Record<AzureAiModelId, ModelInfo & { defaultDeployme
 		outputPrice: 0.06,
 		defaultDeployment: {
 			name: "azure-gpt-4",
-			apiVersion: "2024-02-15-preview"
-		}
+			apiVersion: "2024-02-15-preview",
+		},
 	},
 	"azure-gpt-4-turbo": {
 		maxTokens: 4096,
@@ -686,9 +688,9 @@ export const azureAiModels: Record<AzureAiModelId, ModelInfo & { defaultDeployme
 		outputPrice: 0.03,
 		defaultDeployment: {
 			name: "azure-gpt-4-turbo",
-			apiVersion: "2024-02-15-preview"
-		}
-	}
+			apiVersion: "2024-02-15-preview",
+		},
+	},
 } as const satisfies Record<AzureAiModelId, ModelInfo & { defaultDeployment: AzureDeploymentConfig }>
 
 export const azureAiDefaultModelId: AzureAiModelId = "azure-gpt-35"
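
On the azureAiDeployments type change: replacing the mapped type over AzureAiModelId with a string index signature is what opens the map to arbitrary model IDs; the added | undefined arm is redundant next to the ? modifier but harmless. A short sketch of the difference (type names are illustrative):

// sketch: why the index-signature form accepts dynamic model IDs
type AzureAiModelId = "azure-gpt-35" | "azure-gpt-4" | "azure-gpt-4-turbo"
type Deployment = { name: string; apiVersion: string; modelMeshName?: string }

type ClosedMap = { [K in AzureAiModelId]?: Deployment } // keys limited to the built-in IDs
type OpenMap = { [key: string]: Deployment } // any model ID is a valid key

const before: ClosedMap = { "azure-gpt-35": { name: "d1", apiVersion: "2024-02-15-preview" } }
// const rejected: ClosedMap = { "custom-model": { name: "d2", apiVersion: "2024-02-15-preview" } } // compile error
const after: OpenMap = { "custom-model": { name: "d2", apiVersion: "2024-02-15-preview" } } // accepted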