feat: update Azure AI handler and model picker for improved configuration and error handling

This commit is contained in:
pacnpal
2025-02-02 20:29:31 -05:00
parent 703cda7678
commit 53e307c8f3
3 changed files with 143 additions and 43 deletions

View File

@@ -2,13 +2,12 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import ModelClient from "@azure-rest/ai-inference"
 import { isUnexpected } from "@azure-rest/ai-inference"
 import { AzureKeyCredential } from "@azure/core-auth"
-import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api"
+import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig, azureAiModelInfoSaneDefaults } from "../../shared/api"
 import { ApiHandler, SingleCompletionHandler } from "../index"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
-const DEFAULT_API_VERSION = "2024-02-15-preview"
-const DEFAULT_MAX_TOKENS = 4096
+const DEFAULT_API_VERSION = "2024-05-01-preview"
 export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -32,7 +31,7 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 		const modelId = this.options.apiModelId
 		if (!modelId) {
 			return {
-				name: "gpt-35-turbo", // Default deployment name if none specified
+				name: "default",
 				apiVersion: DEFAULT_API_VERSION,
 			}
 		}
@@ -46,7 +45,6 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 			}
 		}
-		// If no custom config, use model ID as deployment name
 		return {
 			name: modelId,
 			apiVersion: DEFAULT_API_VERSION,
@@ -65,14 +63,17 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 				messages: chatMessages,
 				temperature: 0,
 				stream: true,
-				max_tokens: DEFAULT_MAX_TOKENS,
+				max_tokens: azureAiModelInfoSaneDefaults.maxTokens,
 				response_format: { type: "text" },
 			},
-			headers: deployment.modelMeshName
-				? {
-					"x-ms-model-mesh-model-name": deployment.modelMeshName,
-				}
-				: undefined,
+			headers: {
+				"extra-parameters": "drop",
+				...(deployment.modelMeshName
+					? {
+						"azureml-model-deployment": deployment.modelMeshName,
+					}
+					: {}),
+			},
 		})
 		.asNodeStream()
@@ -118,13 +119,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 	} catch (error) {
 		if (error instanceof Error) {
 			// Handle Azure-specific error cases
-			if ("status" in error && error.status === 429) {
+			if (isUnexpected(error) && error.status === 429) {
 				throw new Error("Azure AI rate limit exceeded. Please try again later.")
 			}
-			if ("status" in error && error.status === 400) {
-				const azureError = error as any
-				if (azureError.body?.error?.code === "ContentFilterError") {
-					throw new Error("Content was flagged by Azure AI content safety filters")
+			if (isUnexpected(error)) {
+				// Use proper Model Inference error handling
+				const message = error.body?.error?.message || error.message
+				if (error.status === 422) {
+					throw new Error(`Request validation failed: ${message}`)
 				}
 			}
 			throw new Error(`Azure AI error: ${error.message}`)
@@ -135,12 +137,8 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 	getModel(): { id: string; info: ModelInfo } {
 		return {
-			id: this.options.apiModelId || "gpt-35-turbo",
-			info: {
-				maxTokens: DEFAULT_MAX_TOKENS,
-				contextWindow: 16385, // Conservative default
-				supportsPromptCache: true,
-			},
+			id: this.options.apiModelId || "default",
+			info: azureAiModelInfoSaneDefaults,
 		}
 	}
@@ -153,11 +151,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 				temperature: 0,
 				response_format: { type: "text" },
 			},
-			headers: deployment.modelMeshName
-				? {
-					"x-ms-model-mesh-model-name": deployment.modelMeshName,
-				}
-				: undefined,
+			headers: {
+				"extra-parameters": "drop",
+				...(deployment.modelMeshName
+					? {
+						"azureml-model-deployment": deployment.modelMeshName,
+					}
+					: {}),
+			},
 		})
 		if (isUnexpected(response)) {
@@ -168,13 +169,13 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 	} catch (error) {
 		if (error instanceof Error) {
 			// Handle Azure-specific error cases
-			if ("status" in error && error.status === 429) {
+			if (isUnexpected(error) && error.status === 429) {
 				throw new Error("Azure AI rate limit exceeded. Please try again later.")
 			}
-			if ("status" in error && error.status === 400) {
-				const azureError = error as any
-				if (azureError.body?.error?.code === "ContentFilterError") {
-					throw new Error("Content was flagged by Azure AI content safety filters")
+			if (isUnexpected(error)) {
+				const message = error.body?.error?.message || error.message
+				if (error.status === 422) {
+					throw new Error(`Request validation failed: ${message}`)
 				}
 			}
 			throw new Error(`Azure AI completion error: ${error.message}`)

View File

@@ -39,6 +39,7 @@ import OpenRouterModelPicker, {
 	OPENROUTER_MODEL_PICKER_Z_INDEX,
 } from "./OpenRouterModelPicker"
 import OpenAiModelPicker from "./OpenAiModelPicker"
+import AzureAiModelPicker from "./AzureAiModelPicker"
 import GlamaModelPicker from "./GlamaModelPicker"
 interface ApiOptionsProps {
@@ -138,6 +139,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 				options={[
 					{ value: "openrouter", label: "OpenRouter" },
 					{ value: "anthropic", label: "Anthropic" },
+					{ value: "azure-ai", label: "Azure AI Model Inference" },
 					{ value: "gemini", label: "Google Gemini" },
 					{ value: "deepseek", label: "DeepSeek" },
 					{ value: "openai-native", label: "OpenAI" },
@@ -208,6 +210,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 				</div>
 			)}
+			{selectedProvider === "azure-ai" && <AzureAiModelPicker />}
+
 			{selectedProvider === "glama" && (
 				<div>
 					<VSCodeTextField
@@ -1556,6 +1560,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 				selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId,
 				selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo,
 			}
+		case "azure-ai":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.apiModelId || "",
+				selectedModelInfo: azureAiModelInfoSaneDefaults,
+			}
 		case "openai":
 			return {
 				selectedProvider: provider,

View File

@@ -1,56 +1,145 @@
 import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
 import { memo } from "react"
 import { useExtensionState } from "../../context/ExtensionStateContext"
+import { Pane } from "vscrui"
+import { azureAiModelInfoSaneDefaults } from "../../../../src/shared/api"
 const AzureAiModelPicker: React.FC = () => {
 	const { apiConfiguration, handleInputChange } = useExtensionState()
 	return (
-		<>
+		<div style={{ display: "flex", flexDirection: "column", rowGap: "5px" }}>
 			<VSCodeTextField
 				value={apiConfiguration?.azureAiEndpoint || ""}
 				style={{ width: "100%" }}
 				type="url"
-				onInput={handleInputChange("azureAiEndpoint")}
-				placeholder="https://ai-services-resource.services.ai.azure.com/models">
-				<span style={{ fontWeight: 500 }}>Azure AI Endpoint</span>
+				onChange={handleInputChange("azureAiEndpoint")}
+				placeholder="https://your-endpoint.region.inference.ai.azure.com">
+				<span style={{ fontWeight: 500 }}>Base URL</span>
 			</VSCodeTextField>
 			<VSCodeTextField
 				value={apiConfiguration?.azureAiKey || ""}
 				style={{ width: "100%" }}
 				type="password"
-				onInput={handleInputChange("azureAiKey")}
+				onChange={handleInputChange("azureAiKey")}
 				placeholder="Enter API Key...">
-				<span style={{ fontWeight: 500 }}>Azure AI Key</span>
+				<span style={{ fontWeight: 500 }}>API Key</span>
 			</VSCodeTextField>
 			<VSCodeTextField
 				value={apiConfiguration?.apiModelId || ""}
 				style={{ width: "100%" }}
 				type="text"
-				onInput={handleInputChange("apiModelId")}
+				onChange={handleInputChange("apiModelId")}
 				placeholder="Enter model deployment name...">
-				<span style={{ fontWeight: 500 }}>Deployment Name</span>
+				<span style={{ fontWeight: 500 }}>Model Deployment Name</span>
 			</VSCodeTextField>
<Pane
title="Model Configuration"
open={false}
actions={[
{
iconName: "refresh",
onClick: () =>
handleInputChange("openAiCustomModelInfo")({
target: { value: azureAiModelInfoSaneDefaults },
}),
},
]}>
<div
style={{
padding: 15,
backgroundColor: "var(--vscode-editor-background)",
}}>
<p
style={{
fontSize: "12px",
color: "var(--vscode-descriptionForeground)",
margin: "0 0 15px 0",
lineHeight: "1.4",
}}>
Configure capabilities for your deployed model.
</p>
<div
style={{
backgroundColor: "var(--vscode-editor-inactiveSelectionBackground)",
padding: "12px",
borderRadius: "4px",
marginTop: "8px",
}}>
<span
style={{
fontSize: "11px",
fontWeight: 500,
color: "var(--vscode-editor-foreground)",
display: "block",
marginBottom: "10px",
}}>
Model Features
</span>
<div style={{ display: "flex", flexDirection: "column", gap: "8px" }}>
<VSCodeTextField
value={
apiConfiguration?.openAiCustomModelInfo?.contextWindow?.toString() ||
azureAiModelInfoSaneDefaults.contextWindow?.toString() ||
""
}
type="text"
style={{ width: "100%" }}
onChange={(e: any) => {
const parsed = parseInt(e.target.value)
handleInputChange("openAiCustomModelInfo")({
target: {
value: {
...(apiConfiguration?.openAiCustomModelInfo ||
azureAiModelInfoSaneDefaults),
contextWindow:
e.target.value === ""
? undefined
: isNaN(parsed)
? azureAiModelInfoSaneDefaults.contextWindow
: parsed,
},
},
})
}}
placeholder="e.g. 128000">
<span style={{ fontWeight: 500 }}>Context Window Size</span>
</VSCodeTextField>
<p
style={{
fontSize: "11px",
color: "var(--vscode-descriptionForeground)",
marginTop: "4px",
}}>
Total tokens the model can process in a single request.
</p>
</div>
</div>
</div>
</Pane>
 			<p
 				style={{
 					fontSize: "12px",
 					marginTop: "5px",
 					color: "var(--vscode-descriptionForeground)",
 				}}>
-				Configure your Azure AI Model Inference endpoint and model deployment. The API key is stored locally.
+				Configure your Azure AI Model Inference endpoint and model deployment. API keys are stored locally.
 				{!apiConfiguration?.azureAiKey && (
 					<VSCodeLink
 						href="https://learn.microsoft.com/azure/ai-foundry/model-inference/reference/reference-model-inference-chat-completions"
 						style={{ display: "inline", fontSize: "inherit" }}>
 						{" "}
-						Learn more about Azure AI Model Inference endpoints.
+						Learn more about Azure AI Model Inference.
 					</VSCodeLink>
 				)}
 			</p>
-		</>
+		</div>
 	)
 }