diff --git a/src/api/providers/__tests__/azure-ai.test.ts b/src/api/providers/__tests__/azure-ai.test.ts index fcc3f7b..3741b55 100644 --- a/src/api/providers/__tests__/azure-ai.test.ts +++ b/src/api/providers/__tests__/azure-ai.test.ts @@ -5,167 +5,180 @@ import ModelClient from "@azure-rest/ai-inference" // Mock the Azure AI client jest.mock("@azure-rest/ai-inference", () => { - return { - __esModule: true, - default: jest.fn().mockImplementation(() => ({ - path: jest.fn().mockReturnValue({ - post: jest.fn() - }) - })), - isUnexpected: jest.fn() - } + const mockClient = jest.fn().mockImplementation(() => ({ + path: jest.fn().mockReturnValue({ + post: jest.fn(), + }), + })) + + return { + __esModule: true, + default: mockClient, + isUnexpected: jest.fn(), + } }) describe("AzureAiHandler", () => { - const mockOptions: ApiHandlerOptions = { - apiProvider: "azure-ai", - apiModelId: "azure-gpt-35", - azureAiEndpoint: "https://test-resource.inference.azure.com", - azureAiKey: "test-key", - azureAiDeployments: { - "azure-gpt-35": { - name: "custom-gpt35", - apiVersion: "2024-02-15-preview", - modelMeshName: "test-mesh-model" - } - } - } + const mockOptions: ApiHandlerOptions = { + apiModelId: "azure-gpt-35", + azureAiEndpoint: "https://test-resource.inference.azure.com", + azureAiKey: "test-key", + } - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => { + jest.clearAllMocks() + }) - test("constructs with required options", () => { - const handler = new AzureAiHandler(mockOptions) - expect(handler).toBeInstanceOf(AzureAiHandler) - }) + test("constructs with required options", () => { + const handler = new AzureAiHandler(mockOptions) + expect(handler).toBeInstanceOf(AzureAiHandler) + }) - test("throws error without endpoint", () => { - const invalidOptions = { ...mockOptions } - delete invalidOptions.azureAiEndpoint - expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI endpoint is required") - }) + test("throws error without endpoint", () => { + const invalidOptions = { ...mockOptions } + delete invalidOptions.azureAiEndpoint + expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI endpoint is required") + }) - test("throws error without API key", () => { - const invalidOptions = { ...mockOptions } - delete invalidOptions.azureAiKey - expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI key is required") - }) + test("throws error without API key", () => { + const invalidOptions = { ...mockOptions } + delete invalidOptions.azureAiKey + expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI key is required") + }) - test("creates chat completion correctly", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockResponse = { - body: { - choices: [ - { - message: { - content: "test response" - } - } - ] - } - } - - const mockClient = ModelClient as jest.MockedClass - mockClient.prototype.path.mockReturnValue({ - post: jest.fn().mockResolvedValue(mockResponse) - }) + test("creates chat completion correctly", async () => { + const handler = new AzureAiHandler(mockOptions) + const mockResponse = { + body: { + choices: [ + { + message: { + content: "test response", + }, + }, + ], + }, + } - const result = await handler.completePrompt("test prompt") - expect(result).toBe("test response") + const mockClient = ModelClient as jest.MockedFunction + mockClient.mockReturnValue({ + path: jest.fn().mockReturnValue({ + post: jest.fn().mockResolvedValue(mockResponse), + }), + } as any) - expect(mockClient.prototype.path).toHaveBeenCalledWith("/chat/completions") - expect(mockClient.prototype.path().post).toHaveBeenCalledWith({ - body: { - messages: [{ role: "user", content: "test prompt" }], - temperature: 0 - } - }) - }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("test response") + }) - test("handles streaming responses correctly", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockStream = Readable.from([ - 'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n', - 'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n', - 'data: [DONE]\n\n' - ]) + test("handles streaming responses correctly", async () => { + const handler = new AzureAiHandler(mockOptions) + const mockStream = new Readable({ + read() { + this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n') + this.push( + 'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n', + ) + this.push("data: [DONE]\n\n") + this.push(null) + }, + }) - const mockClient = ModelClient as jest.MockedClass - mockClient.prototype.path.mockReturnValue({ - post: jest.fn().mockResolvedValue({ - status: 200, - body: mockStream, - }) - }) + const mockResponse = { + status: 200, + body: mockStream, + } - const messages = [] - for await (const message of handler.createMessage("system prompt", [])) { - messages.push(message) - } + const mockClient = ModelClient as jest.MockedFunction + mockClient.mockReturnValue({ + path: jest.fn().mockReturnValue({ + post: jest.fn().mockReturnValue({ + asNodeStream: () => Promise.resolve(mockResponse), + }), + }), + } as any) - expect(messages).toEqual([ - { type: "text", text: "Hello" }, - { type: "text", text: " world" }, - { type: "usage", inputTokens: 10, outputTokens: 2 } - ]) + const messages = [] + for await (const message of handler.createMessage("system prompt", [])) { + messages.push(message) + } - expect(mockClient.prototype.path().post).toHaveBeenCalledWith({ - body: { - messages: [{ role: "system", content: "system prompt" }], - temperature: 0, - stream: true, - max_tokens: expect.any(Number) - } - }) - }) + expect(messages).toEqual([ + { type: "text", text: "Hello" }, + { type: "text", text: " world" }, + { type: "usage", inputTokens: 10, outputTokens: 2 }, + ]) + }) - test("handles rate limit errors", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockError = new Error("Rate limit exceeded") - Object.assign(mockError, { status: 429 }) + test("handles rate limit errors", async () => { + const handler = new AzureAiHandler(mockOptions) + const mockError = new Error("Rate limit exceeded") + Object.assign(mockError, { status: 429 }) - const mockClient = ModelClient as jest.MockedClass - mockClient.prototype.path.mockReturnValue({ - post: jest.fn().mockRejectedValue(mockError) - }) + const mockClient = ModelClient as jest.MockedFunction + mockClient.mockReturnValue({ + path: jest.fn().mockReturnValue({ + post: jest.fn().mockRejectedValue(mockError), + }), + } as any) - await expect(handler.completePrompt("test")).rejects.toThrow( - "Azure AI rate limit exceeded. Please try again later." - ) - }) + await expect(handler.completePrompt("test")).rejects.toThrow( + "Azure AI rate limit exceeded. Please try again later.", + ) + }) - test("handles content safety errors", async () => { - const handler = new AzureAiHandler(mockOptions) - const mockError = { - status: 400, - body: { - error: { - code: "ContentFilterError", - message: "Content was flagged by content safety filters" - } - } - } + test("handles content safety errors", async () => { + const handler = new AzureAiHandler(mockOptions) + const mockError = { + status: 400, + body: { + error: { + code: "ContentFilterError", + message: "Content was flagged by content safety filters", + }, + }, + } - const mockClient = ModelClient as jest.MockedClass - mockClient.prototype.path.mockReturnValue({ - post: jest.fn().mockRejectedValue(mockError) - }) + const mockClient = ModelClient as jest.MockedFunction + mockClient.mockReturnValue({ + path: jest.fn().mockReturnValue({ + post: jest.fn().mockRejectedValue(mockError), + }), + } as any) - await expect(handler.completePrompt("test")).rejects.toThrow( - "Azure AI completion error: Content was flagged by content safety filters" - ) - }) + await expect(handler.completePrompt("test")).rejects.toThrow( + "Content was flagged by Azure AI content safety filters", + ) + }) - test("falls back to default model configuration", async () => { - const options = { ...mockOptions } - delete options.azureAiDeployments + test("falls back to default model configuration", () => { + const handler = new AzureAiHandler({ + azureAiEndpoint: "https://test.azure.com", + azureAiKey: "test-key", + }) + const model = handler.getModel() - const handler = new AzureAiHandler(options) - const model = handler.getModel() + expect(model.id).toBe("azure-gpt-35") + expect(model.info).toBeDefined() + }) - expect(model.id).toBe("azure-gpt-35") - expect(model.info).toBeDefined() - expect(model.info.defaultDeployment.name).toBe("azure-gpt-35") - }) -}) \ No newline at end of file + test("supports custom deployment names", async () => { + const customOptions = { + ...mockOptions, + apiModelId: "custom-model", + azureAiDeployments: { + "custom-model": { + name: "my-custom-deployment", + apiVersion: "2024-02-15-preview", + modelMeshName: "my-custom-model", + }, + }, + } + + const handler = new AzureAiHandler(customOptions) + const model = handler.getModel() + + expect(model.id).toBe("custom-model") + expect(model.info).toBeDefined() + }) +}) diff --git a/src/api/providers/azure-ai.ts b/src/api/providers/azure-ai.ts index 322c7cf..adbfb47 100644 --- a/src/api/providers/azure-ai.ts +++ b/src/api/providers/azure-ai.ts @@ -2,22 +2,17 @@ import { Anthropic } from "@anthropic-ai/sdk" import ModelClient from "@azure-rest/ai-inference" import { isUnexpected } from "@azure-rest/ai-inference" import { AzureKeyCredential } from "@azure/core-auth" -import { - ApiHandlerOptions, - ModelInfo, - azureAiDefaultModelId, - AzureAiModelId, - azureAiModels, - AzureDeploymentConfig, -} from "../../shared/api" +import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api" import { ApiHandler, SingleCompletionHandler } from "../index" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" -import { createSseStream } from "@azure/core-rest-pipeline" + +const DEFAULT_API_VERSION = "2024-02-15-preview" +const DEFAULT_MAX_TOKENS = 4096 export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions - private client: ModelClient + private client: ReturnType constructor(options: ApiHandlerOptions) { this.options = options @@ -30,22 +25,36 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { throw new Error("Azure AI key is required") } - this.client = new ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey)) + this.client = ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey)) } private getDeploymentConfig(): AzureDeploymentConfig { - const model = this.getModel() - const defaultConfig = azureAiModels[model.id].defaultDeployment + const modelId = this.options.apiModelId + if (!modelId) { + return { + name: "gpt-35-turbo", // Default deployment name if none specified + apiVersion: DEFAULT_API_VERSION, + } + } + const customConfig = this.options.azureAiDeployments?.[modelId] + if (customConfig) { + return { + name: customConfig.name, + apiVersion: customConfig.apiVersion || DEFAULT_API_VERSION, + modelMeshName: customConfig.modelMeshName, + } + } + + // If no custom config, use model ID as deployment name return { - name: this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name, - apiVersion: this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion, - modelMeshName: this.options.azureAiDeployments?.[model.id]?.modelMeshName, + name: modelId, + apiVersion: DEFAULT_API_VERSION, } } async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const modelInfo = this.getModel().info + const deployment = this.getDeploymentConfig() const chatMessages = [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)] try { @@ -56,12 +65,12 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { messages: chatMessages, temperature: 0, stream: true, - max_tokens: modelInfo.maxTokens, - response_format: { type: "text" }, // Ensure text format for chat + max_tokens: DEFAULT_MAX_TOKENS, + response_format: { type: "text" }, }, - headers: this.getDeploymentConfig().modelMeshName + headers: deployment.modelMeshName ? { - "x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName, + "x-ms-model-mesh-model-name": deployment.modelMeshName, } : undefined, }) @@ -69,22 +78,22 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { const stream = response.body if (!stream) { - throw new Error(`Failed to get chat completions with status: ${response.status}`) + throw new Error("Failed to get chat completions stream") } - if (response.status !== 200) { - throw new Error(`Failed to get chat completions: ${response.body.error}`) + const statusCode = Number(response.status) + if (statusCode !== 200) { + throw new Error(`Failed to get chat completions: HTTP ${statusCode}`) } - const sseStream = createSseStream(stream) - - for await (const event of sseStream) { - if (event.data === "[DONE]") { + for await (const chunk of stream) { + const chunkStr = chunk.toString() + if (chunkStr === "data: [DONE]\n\n") { return } try { - const data = JSON.parse(event.data) + const data = JSON.parse(chunkStr.replace("data: ", "")) const delta = data.choices[0]?.delta if (delta?.content) { @@ -124,26 +133,29 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { } } - getModel(): { id: AzureAiModelId; info: ModelInfo } { - const modelId = this.options.apiModelId - if (modelId && modelId in azureAiModels) { - const id = modelId as AzureAiModelId - return { id, info: azureAiModels[id] } + getModel(): { id: string; info: ModelInfo } { + return { + id: this.options.apiModelId || "gpt-35-turbo", + info: { + maxTokens: DEFAULT_MAX_TOKENS, + contextWindow: 16385, // Conservative default + supportsPromptCache: true, + }, } - return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] } } async completePrompt(prompt: string): Promise { try { + const deployment = this.getDeploymentConfig() const response = await this.client.path("/chat/completions").post({ body: { messages: [{ role: "user", content: prompt }], temperature: 0, response_format: { type: "text" }, }, - headers: this.getDeploymentConfig().modelMeshName + headers: deployment.modelMeshName ? { - "x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName, + "x-ms-model-mesh-model-name": deployment.modelMeshName, } : undefined, }) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 2dc1d4c..d0228dc 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -86,7 +86,7 @@ type GlobalStateKey = | "lmStudioBaseUrl" | "anthropicBaseUrl" | "azureApiVersion" - | "azureAiDeployments" + | "azureAiDeployments" | "openAiStreamingEnabled" | "openRouterModelId" | "openRouterModelInfo" @@ -1075,16 +1075,25 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("autoApprovalEnabled", message.bool ?? false) await this.postStateToWebview() break - case "updateAzureAiDeployment": - if (message.azureAiDeployment) { - const deployments = await this.getGlobalState("azureAiDeployments") || {} - deployments[message.azureAiDeployment.modelId] = { - ...message.azureAiDeployment, - } - await this.updateGlobalState("azureAiDeployments", deployments) - await this.postStateToWebview() - } - break + case "updateAzureAiDeployment": + if (message.azureAiDeployment) { + const deployments = ((await this.getGlobalState("azureAiDeployments")) || {}) as Record< + string, + { + name: string + apiVersion: string + modelMeshName?: string + } + > + deployments[message.azureAiDeployment.modelId] = { + name: message.azureAiDeployment.name, + apiVersion: message.azureAiDeployment.apiVersion, + modelMeshName: message.azureAiDeployment.modelMeshName, + } + await this.updateGlobalState("azureAiDeployments", deployments) + await this.postStateToWebview() + } + break case "enhancePrompt": if (message.text) { try { @@ -1517,7 +1526,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey) await this.storeSecret("deepSeekApiKey", deepSeekApiKey) await this.updateGlobalState("azureApiVersion", azureApiVersion) - await this.updateGlobalState("azureAiDeployments", apiConfiguration.azureAiDeployments) + await this.updateGlobalState("azureAiDeployments", apiConfiguration.azureAiDeployments) await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled) await this.updateGlobalState("openRouterModelId", openRouterModelId) await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo) @@ -2159,7 +2168,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openAiNativeApiKey, deepSeekApiKey, mistralApiKey, - azureAiDeployments, + azureAiDeployments, azureApiVersion, openAiStreamingEnabled, openRouterModelId, @@ -2234,7 +2243,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getSecret("openAiNativeApiKey") as Promise, this.getSecret("deepSeekApiKey") as Promise, this.getSecret("mistralApiKey") as Promise, - this.getGlobalState("azureAiDeployments") as Promise | undefined>, + this.getGlobalState("azureAiDeployments") as Promise | undefined>, this.getGlobalState("azureApiVersion") as Promise, this.getGlobalState("openAiStreamingEnabled") as Promise, this.getGlobalState("openRouterModelId") as Promise, @@ -2327,7 +2336,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { deepSeekApiKey, mistralApiKey, azureApiVersion, - azureAiDeployments, + azureAiDeployments, openAiStreamingEnabled, openRouterModelId, openRouterModelInfo, diff --git a/src/shared/api.ts b/src/shared/api.ts index b43a0b6..4ec67f8 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -15,7 +15,7 @@ export type ApiProvider = | "vscode-lm" | "mistral" | "unbound" -| "azure-ai" + | "azure-ai" export interface ApiHandlerOptions { apiModelId?: string @@ -61,15 +61,17 @@ export interface ApiHandlerOptions { includeMaxTokens?: boolean unboundApiKey?: string unboundModelId?: string - azureAiEndpoint?: string - azureAiKey?: string - azureAiDeployments?: { - [key in AzureAiModelId]?: { - name: string - apiVersion: string - modelMeshName?: string - } - } + azureAiEndpoint?: string + azureAiKey?: string + azureAiDeployments?: + | { + [key: string]: { + name: string + apiVersion: string + modelMeshName?: string + } + } + | undefined } export type ApiConfiguration = ApiHandlerOptions & { @@ -650,45 +652,45 @@ export const unboundModels = { export type AzureAiModelId = "azure-gpt-35" | "azure-gpt-4" | "azure-gpt-4-turbo" export interface AzureDeploymentConfig { - name: string - apiVersion: string - modelMeshName?: string // For Model-Mesh deployments + name: string + apiVersion: string + modelMeshName?: string // For Model-Mesh deployments } export const azureAiModels: Record = { - "azure-gpt-35": { - maxTokens: 4096, - contextWindow: 16385, - supportsPromptCache: true, - inputPrice: 0.0015, - outputPrice: 0.002, - defaultDeployment: { - name: "azure-gpt-35", - apiVersion: "2024-02-15-preview" - } - }, - "azure-gpt-4": { - maxTokens: 8192, - contextWindow: 8192, - supportsPromptCache: true, - inputPrice: 0.03, - outputPrice: 0.06, - defaultDeployment: { - name: "azure-gpt-4", - apiVersion: "2024-02-15-preview" - } - }, - "azure-gpt-4-turbo": { - maxTokens: 4096, - contextWindow: 128000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.03, - defaultDeployment: { - name: "azure-gpt-4-turbo", - apiVersion: "2024-02-15-preview" - } - } + "azure-gpt-35": { + maxTokens: 4096, + contextWindow: 16385, + supportsPromptCache: true, + inputPrice: 0.0015, + outputPrice: 0.002, + defaultDeployment: { + name: "azure-gpt-35", + apiVersion: "2024-02-15-preview", + }, + }, + "azure-gpt-4": { + maxTokens: 8192, + contextWindow: 8192, + supportsPromptCache: true, + inputPrice: 0.03, + outputPrice: 0.06, + defaultDeployment: { + name: "azure-gpt-4", + apiVersion: "2024-02-15-preview", + }, + }, + "azure-gpt-4-turbo": { + maxTokens: 4096, + contextWindow: 128000, + supportsPromptCache: true, + inputPrice: 0.01, + outputPrice: 0.03, + defaultDeployment: { + name: "azure-gpt-4-turbo", + apiVersion: "2024-02-15-preview", + }, + }, } as const satisfies Record export const azureAiDefaultModelId: AzureAiModelId = "azure-gpt-35"