From 53e307c8f3097a7d58ca2b90ea5436697961eecd Mon Sep 17 00:00:00 2001 From: pacnpal <183241239+pacnpal@users.noreply.github.com> Date: Sun, 2 Feb 2025 20:29:31 -0500 Subject: [PATCH] feat: update Azure AI handler and model picker for improved configuration and error handling --- src/api/providers/azure-ai.ts | 65 +++++----- .../src/components/settings/ApiOptions.tsx | 10 ++ .../settings/AzureAiModelPicker.tsx | 111 ++++++++++++++++-- 3 files changed, 143 insertions(+), 43 deletions(-) diff --git a/src/api/providers/azure-ai.ts b/src/api/providers/azure-ai.ts index adbfb47..47268f6 100644 --- a/src/api/providers/azure-ai.ts +++ b/src/api/providers/azure-ai.ts @@ -2,13 +2,12 @@ import { Anthropic } from "@anthropic-ai/sdk" import ModelClient from "@azure-rest/ai-inference" import { isUnexpected } from "@azure-rest/ai-inference" import { AzureKeyCredential } from "@azure/core-auth" -import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api" +import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig, azureAiModelInfoSaneDefaults } from "../../shared/api" import { ApiHandler, SingleCompletionHandler } from "../index" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" -const DEFAULT_API_VERSION = "2024-02-15-preview" -const DEFAULT_MAX_TOKENS = 4096 +const DEFAULT_API_VERSION = "2024-05-01-preview" export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions @@ -32,7 +31,7 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { const modelId = this.options.apiModelId if (!modelId) { return { - name: "gpt-35-turbo", // Default deployment name if none specified + name: "default", apiVersion: DEFAULT_API_VERSION, } } @@ -46,7 +45,6 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { } } - // If no custom config, use model ID as deployment name return { name: modelId, apiVersion: DEFAULT_API_VERSION, @@ -65,14 +63,17 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { messages: chatMessages, temperature: 0, stream: true, - max_tokens: DEFAULT_MAX_TOKENS, + max_tokens: azureAiModelInfoSaneDefaults.maxTokens, response_format: { type: "text" }, }, - headers: deployment.modelMeshName - ? { - "x-ms-model-mesh-model-name": deployment.modelMeshName, - } - : undefined, + headers: { + "extra-parameters": "drop", + ...(deployment.modelMeshName + ? { + "azureml-model-deployment": deployment.modelMeshName, + } + : {}), + }, }) .asNodeStream() @@ -118,13 +119,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { } catch (error) { if (error instanceof Error) { // Handle Azure-specific error cases - if ("status" in error && error.status === 429) { + if (isUnexpected(error) && error.status === 429) { throw new Error("Azure AI rate limit exceeded. Please try again later.") } - if ("status" in error && error.status === 400) { - const azureError = error as any - if (azureError.body?.error?.code === "ContentFilterError") { - throw new Error("Content was flagged by Azure AI content safety filters") + if (isUnexpected(error)) { + // Use proper Model Inference error handling + const message = error.body?.error?.message || error.message + if (error.status === 422) { + throw new Error(`Request validation failed: ${message}`) } } throw new Error(`Azure AI error: ${error.message}`) @@ -135,12 +137,8 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { getModel(): { id: string; info: ModelInfo } { return { - id: this.options.apiModelId || "gpt-35-turbo", - info: { - maxTokens: DEFAULT_MAX_TOKENS, - contextWindow: 16385, // Conservative default - supportsPromptCache: true, - }, + id: this.options.apiModelId || "default", + info: azureAiModelInfoSaneDefaults, } } @@ -153,11 +151,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { temperature: 0, response_format: { type: "text" }, }, - headers: deployment.modelMeshName - ? { - "x-ms-model-mesh-model-name": deployment.modelMeshName, - } - : undefined, + headers: { + "extra-parameters": "drop", + ...(deployment.modelMeshName + ? { + "azureml-model-deployment": deployment.modelMeshName, + } + : {}), + }, }) if (isUnexpected(response)) { @@ -168,13 +169,13 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { } catch (error) { if (error instanceof Error) { // Handle Azure-specific error cases - if ("status" in error && error.status === 429) { + if (isUnexpected(error) && error.status === 429) { throw new Error("Azure AI rate limit exceeded. Please try again later.") } - if ("status" in error && error.status === 400) { - const azureError = error as any - if (azureError.body?.error?.code === "ContentFilterError") { - throw new Error("Content was flagged by Azure AI content safety filters") + if (isUnexpected(error)) { + const message = error.body?.error?.message || error.message + if (error.status === 422) { + throw new Error(`Request validation failed: ${message}`) } } throw new Error(`Azure AI completion error: ${error.message}`) diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 4bdff0b..75ba67f 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -39,6 +39,7 @@ import OpenRouterModelPicker, { OPENROUTER_MODEL_PICKER_Z_INDEX, } from "./OpenRouterModelPicker" import OpenAiModelPicker from "./OpenAiModelPicker" +import AzureAiModelPicker from "./AzureAiModelPicker" import GlamaModelPicker from "./GlamaModelPicker" interface ApiOptionsProps { @@ -138,6 +139,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = options={[ { value: "openrouter", label: "OpenRouter" }, { value: "anthropic", label: "Anthropic" }, + { value: "azure-ai", label: "Azure AI Model Inference" }, { value: "gemini", label: "Google Gemini" }, { value: "deepseek", label: "DeepSeek" }, { value: "openai-native", label: "OpenAI" }, @@ -208,6 +210,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = )} + {selectedProvider === "azure-ai" && } + {selectedProvider === "glama" && (
{ const { apiConfiguration, handleInputChange } = useExtensionState() return ( - <> +
- Azure AI Endpoint + onChange={handleInputChange("azureAiEndpoint")} + placeholder="https://your-endpoint.region.inference.ai.azure.com"> + Base URL - Azure AI Key + API Key - Deployment Name + Model Deployment Name + + handleInputChange("openAiCustomModelInfo")({ + target: { value: azureAiModelInfoSaneDefaults }, + }), + }, + ]}> +
+

+ Configure capabilities for your deployed model. +

+ +
+ + Model Features + + +
+ { + const parsed = parseInt(e.target.value) + handleInputChange("openAiCustomModelInfo")({ + target: { + value: { + ...(apiConfiguration?.openAiCustomModelInfo || + azureAiModelInfoSaneDefaults), + contextWindow: + e.target.value === "" + ? undefined + : isNaN(parsed) + ? azureAiModelInfoSaneDefaults.contextWindow + : parsed, + }, + }, + }) + }} + placeholder="e.g. 128000"> + Context Window Size + +

+ Total tokens the model can process in a single request. +

+
+
+
+
+

- Configure your Azure AI Model Inference endpoint and model deployment. The API key is stored locally. + Configure your Azure AI Model Inference endpoint and model deployment. API keys are stored locally. {!apiConfiguration?.azureAiKey && ( {" "} - Learn more about Azure AI Model Inference endpoints. + Learn more about Azure AI Model Inference. )}

- +
) }