diff --git a/src/api/providers/azure-ai.ts b/src/api/providers/azure-ai.ts
index adbfb47..47268f6 100644
--- a/src/api/providers/azure-ai.ts
+++ b/src/api/providers/azure-ai.ts
@@ -2,13 +2,12 @@ import { Anthropic } from "@anthropic-ai/sdk"
import ModelClient from "@azure-rest/ai-inference"
import { isUnexpected } from "@azure-rest/ai-inference"
import { AzureKeyCredential } from "@azure/core-auth"
-import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api"
+import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig, azureAiModelInfoSaneDefaults } from "../../shared/api"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
-const DEFAULT_API_VERSION = "2024-02-15-preview"
-const DEFAULT_MAX_TOKENS = 4096
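+// Default Azure AI Model Inference API version, used when a deployment config does not specify its own.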
+const DEFAULT_API_VERSION = "2024-05-01-preview"
export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
private options: ApiHandlerOptions
@@ -32,7 +31,7 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
const modelId = this.options.apiModelId
if (!modelId) {
return {
- name: "gpt-35-turbo", // Default deployment name if none specified
+ name: "default",
apiVersion: DEFAULT_API_VERSION,
}
}
@@ -46,7 +45,6 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
}
}
- // If no custom config, use model ID as deployment name
return {
name: modelId,
apiVersion: DEFAULT_API_VERSION,
@@ -65,14 +63,17 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
messages: chatMessages,
temperature: 0,
stream: true,
- max_tokens: DEFAULT_MAX_TOKENS,
+ max_tokens: azureAiModelInfoSaneDefaults.maxTokens,
response_format: { type: "text" },
},
- headers: deployment.modelMeshName
- ? {
- "x-ms-model-mesh-model-name": deployment.modelMeshName,
- }
- : undefined,
+ headers: {
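+ // Ask the service to drop request parameters the deployed model does not support instead of rejecting the call.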
+ "extra-parameters": "drop",
+ ...(deployment.modelMeshName
+ ? {
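+ // Route the request to a specific deployment behind the Azure ML endpoint.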
+ "azureml-model-deployment": deployment.modelMeshName,
+ }
+ : {}),
+ },
})
.asNodeStream()
@@ -118,13 +119,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
} catch (error) {
if (error instanceof Error) {
// Handle Azure-specific error cases
- if ("status" in error && error.status === 429) {
+ if (isUnexpected(error) && error.status === 429) {
throw new Error("Azure AI rate limit exceeded. Please try again later.")
}
- if ("status" in error && error.status === 400) {
- const azureError = error as any
- if (azureError.body?.error?.code === "ContentFilterError") {
- throw new Error("Content was flagged by Azure AI content safety filters")
+ if (isUnexpected(error)) {
+ // Prefer the structured error body returned by the Model Inference service
+ const message = error.body?.error?.message || error.message
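+ // 422 indicates the request body failed validation (for example, parameters the model does not accept)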
+ if (error.status === 422) {
+ throw new Error(`Request validation failed: ${message}`)
}
}
throw new Error(`Azure AI error: ${error.message}`)
@@ -135,12 +137,8 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
getModel(): { id: string; info: ModelInfo } {
return {
- id: this.options.apiModelId || "gpt-35-turbo",
- info: {
- maxTokens: DEFAULT_MAX_TOKENS,
- contextWindow: 16385, // Conservative default
- supportsPromptCache: true,
- },
+ id: this.options.apiModelId || "default",
+ info: azureAiModelInfoSaneDefaults,
}
}
@@ -153,11 +151,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
temperature: 0,
response_format: { type: "text" },
},
- headers: deployment.modelMeshName
- ? {
- "x-ms-model-mesh-model-name": deployment.modelMeshName,
- }
- : undefined,
+ headers: {
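+ // Same extra-parameters and deployment-routing headers as the streaming request.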
+ "extra-parameters": "drop",
+ ...(deployment.modelMeshName
+ ? {
+ "azureml-model-deployment": deployment.modelMeshName,
+ }
+ : {}),
+ },
})
if (isUnexpected(response)) {
@@ -168,13 +169,13 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
} catch (error) {
if (error instanceof Error) {
// Handle Azure-specific error cases
- if ("status" in error && error.status === 429) {
+ if (isUnexpected(error) && error.status === 429) {
throw new Error("Azure AI rate limit exceeded. Please try again later.")
}
- if ("status" in error && error.status === 400) {
- const azureError = error as any
- if (azureError.body?.error?.code === "ContentFilterError") {
- throw new Error("Content was flagged by Azure AI content safety filters")
+ if (isUnexpected(error)) {
+ const message = error.body?.error?.message || error.message
+ if (error.status === 422) {
+ throw new Error(`Request validation failed: ${message}`)
}
}
throw new Error(`Azure AI completion error: ${error.message}`)
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 4bdff0b..75ba67f 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -39,6 +39,7 @@ import OpenRouterModelPicker, {
OPENROUTER_MODEL_PICKER_Z_INDEX,
} from "./OpenRouterModelPicker"
import OpenAiModelPicker from "./OpenAiModelPicker"
+import AzureAiModelPicker from "./AzureAiModelPicker"
import GlamaModelPicker from "./GlamaModelPicker"
interface ApiOptionsProps {
@@ -138,6 +139,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
options={[
{ value: "openrouter", label: "OpenRouter" },
{ value: "anthropic", label: "Anthropic" },
+ { value: "azure-ai", label: "Azure AI Model Inference" },
{ value: "gemini", label: "Google Gemini" },
{ value: "deepseek", label: "DeepSeek" },
{ value: "openai-native", label: "OpenAI" },
@@ -208,6 +210,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
)}
+ {selectedProvider === "azure-ai" && (
+ <div>
+ <AzureAiModelPicker />
+ <p>Configure capabilities for your deployed model.</p>
+ <p>Total tokens the model can process in a single request.</p>
+ </div>
+ )}
- Configure your Azure AI Model Inference endpoint and model deployment. The API key is stored locally.
+ Configure your Azure AI Model Inference endpoint and model deployment. API keys are stored locally.
{!apiConfiguration?.azureAiKey && (