mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
feat: update Azure AI handler and model picker for improved configuration and error handling
This commit is contained in:
@@ -2,13 +2,12 @@ import { Anthropic } from "@anthropic-ai/sdk"
|
|||||||
import ModelClient from "@azure-rest/ai-inference"
|
import ModelClient from "@azure-rest/ai-inference"
|
||||||
import { isUnexpected } from "@azure-rest/ai-inference"
|
import { isUnexpected } from "@azure-rest/ai-inference"
|
||||||
import { AzureKeyCredential } from "@azure/core-auth"
|
import { AzureKeyCredential } from "@azure/core-auth"
|
||||||
import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api"
|
import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig, azureAiModelInfoSaneDefaults } from "../../shared/api"
|
||||||
import { ApiHandler, SingleCompletionHandler } from "../index"
|
import { ApiHandler, SingleCompletionHandler } from "../index"
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStream } from "../transform/stream"
|
import { ApiStream } from "../transform/stream"
|
||||||
|
|
||||||
const DEFAULT_API_VERSION = "2024-02-15-preview"
|
const DEFAULT_API_VERSION = "2024-05-01-preview"
|
||||||
const DEFAULT_MAX_TOKENS = 4096
|
|
||||||
|
|
||||||
export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
@@ -32,7 +31,7 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
const modelId = this.options.apiModelId
|
const modelId = this.options.apiModelId
|
||||||
if (!modelId) {
|
if (!modelId) {
|
||||||
return {
|
return {
|
||||||
name: "gpt-35-turbo", // Default deployment name if none specified
|
name: "default",
|
||||||
apiVersion: DEFAULT_API_VERSION,
|
apiVersion: DEFAULT_API_VERSION,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -46,7 +45,6 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no custom config, use model ID as deployment name
|
|
||||||
return {
|
return {
|
||||||
name: modelId,
|
name: modelId,
|
||||||
apiVersion: DEFAULT_API_VERSION,
|
apiVersion: DEFAULT_API_VERSION,
|
||||||
@@ -65,14 +63,17 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
messages: chatMessages,
|
messages: chatMessages,
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
stream: true,
|
stream: true,
|
||||||
max_tokens: DEFAULT_MAX_TOKENS,
|
max_tokens: azureAiModelInfoSaneDefaults.maxTokens,
|
||||||
response_format: { type: "text" },
|
response_format: { type: "text" },
|
||||||
},
|
},
|
||||||
headers: deployment.modelMeshName
|
headers: {
|
||||||
? {
|
"extra-parameters": "drop",
|
||||||
"x-ms-model-mesh-model-name": deployment.modelMeshName,
|
...(deployment.modelMeshName
|
||||||
}
|
? {
|
||||||
: undefined,
|
"azureml-model-deployment": deployment.modelMeshName,
|
||||||
|
}
|
||||||
|
: {}),
|
||||||
|
},
|
||||||
})
|
})
|
||||||
.asNodeStream()
|
.asNodeStream()
|
||||||
|
|
||||||
@@ -118,13 +119,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
// Handle Azure-specific error cases
|
// Handle Azure-specific error cases
|
||||||
if ("status" in error && error.status === 429) {
|
if (isUnexpected(error) && error.status === 429) {
|
||||||
throw new Error("Azure AI rate limit exceeded. Please try again later.")
|
throw new Error("Azure AI rate limit exceeded. Please try again later.")
|
||||||
}
|
}
|
||||||
if ("status" in error && error.status === 400) {
|
if (isUnexpected(error)) {
|
||||||
const azureError = error as any
|
// Use proper Model Inference error handling
|
||||||
if (azureError.body?.error?.code === "ContentFilterError") {
|
const message = error.body?.error?.message || error.message
|
||||||
throw new Error("Content was flagged by Azure AI content safety filters")
|
if (error.status === 422) {
|
||||||
|
throw new Error(`Request validation failed: ${message}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw new Error(`Azure AI error: ${error.message}`)
|
throw new Error(`Azure AI error: ${error.message}`)
|
||||||
@@ -135,12 +137,8 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
|
|
||||||
getModel(): { id: string; info: ModelInfo } {
|
getModel(): { id: string; info: ModelInfo } {
|
||||||
return {
|
return {
|
||||||
id: this.options.apiModelId || "gpt-35-turbo",
|
id: this.options.apiModelId || "default",
|
||||||
info: {
|
info: azureAiModelInfoSaneDefaults,
|
||||||
maxTokens: DEFAULT_MAX_TOKENS,
|
|
||||||
contextWindow: 16385, // Conservative default
|
|
||||||
supportsPromptCache: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,11 +151,14 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
temperature: 0,
|
temperature: 0,
|
||||||
response_format: { type: "text" },
|
response_format: { type: "text" },
|
||||||
},
|
},
|
||||||
headers: deployment.modelMeshName
|
headers: {
|
||||||
? {
|
"extra-parameters": "drop",
|
||||||
"x-ms-model-mesh-model-name": deployment.modelMeshName,
|
...(deployment.modelMeshName
|
||||||
}
|
? {
|
||||||
: undefined,
|
"azureml-model-deployment": deployment.modelMeshName,
|
||||||
|
}
|
||||||
|
: {}),
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if (isUnexpected(response)) {
|
if (isUnexpected(response)) {
|
||||||
@@ -168,13 +169,13 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
// Handle Azure-specific error cases
|
// Handle Azure-specific error cases
|
||||||
if ("status" in error && error.status === 429) {
|
if (isUnexpected(error) && error.status === 429) {
|
||||||
throw new Error("Azure AI rate limit exceeded. Please try again later.")
|
throw new Error("Azure AI rate limit exceeded. Please try again later.")
|
||||||
}
|
}
|
||||||
if ("status" in error && error.status === 400) {
|
if (isUnexpected(error)) {
|
||||||
const azureError = error as any
|
const message = error.body?.error?.message || error.message
|
||||||
if (azureError.body?.error?.code === "ContentFilterError") {
|
if (error.status === 422) {
|
||||||
throw new Error("Content was flagged by Azure AI content safety filters")
|
throw new Error(`Request validation failed: ${message}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw new Error(`Azure AI completion error: ${error.message}`)
|
throw new Error(`Azure AI completion error: ${error.message}`)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ import OpenRouterModelPicker, {
|
|||||||
OPENROUTER_MODEL_PICKER_Z_INDEX,
|
OPENROUTER_MODEL_PICKER_Z_INDEX,
|
||||||
} from "./OpenRouterModelPicker"
|
} from "./OpenRouterModelPicker"
|
||||||
import OpenAiModelPicker from "./OpenAiModelPicker"
|
import OpenAiModelPicker from "./OpenAiModelPicker"
|
||||||
|
import AzureAiModelPicker from "./AzureAiModelPicker"
|
||||||
import GlamaModelPicker from "./GlamaModelPicker"
|
import GlamaModelPicker from "./GlamaModelPicker"
|
||||||
|
|
||||||
interface ApiOptionsProps {
|
interface ApiOptionsProps {
|
||||||
@@ -138,6 +139,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
|
|||||||
options={[
|
options={[
|
||||||
{ value: "openrouter", label: "OpenRouter" },
|
{ value: "openrouter", label: "OpenRouter" },
|
||||||
{ value: "anthropic", label: "Anthropic" },
|
{ value: "anthropic", label: "Anthropic" },
|
||||||
|
{ value: "azure-ai", label: "Azure AI Model Inference" },
|
||||||
{ value: "gemini", label: "Google Gemini" },
|
{ value: "gemini", label: "Google Gemini" },
|
||||||
{ value: "deepseek", label: "DeepSeek" },
|
{ value: "deepseek", label: "DeepSeek" },
|
||||||
{ value: "openai-native", label: "OpenAI" },
|
{ value: "openai-native", label: "OpenAI" },
|
||||||
@@ -208,6 +210,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{selectedProvider === "azure-ai" && <AzureAiModelPicker />}
|
||||||
|
|
||||||
{selectedProvider === "glama" && (
|
{selectedProvider === "glama" && (
|
||||||
<div>
|
<div>
|
||||||
<VSCodeTextField
|
<VSCodeTextField
|
||||||
@@ -1556,6 +1560,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
|
|||||||
selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId,
|
selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId,
|
||||||
selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo,
|
selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo,
|
||||||
}
|
}
|
||||||
|
case "azure-ai":
|
||||||
|
return {
|
||||||
|
selectedProvider: provider,
|
||||||
|
selectedModelId: apiConfiguration?.apiModelId || "",
|
||||||
|
selectedModelInfo: azureAiModelInfoSaneDefaults,
|
||||||
|
}
|
||||||
case "openai":
|
case "openai":
|
||||||
return {
|
return {
|
||||||
selectedProvider: provider,
|
selectedProvider: provider,
|
||||||
|
|||||||
@@ -1,56 +1,145 @@
|
|||||||
import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
|
import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
|
||||||
import { memo } from "react"
|
import { memo } from "react"
|
||||||
import { useExtensionState } from "../../context/ExtensionStateContext"
|
import { useExtensionState } from "../../context/ExtensionStateContext"
|
||||||
|
import { Pane } from "vscrui"
|
||||||
|
import { azureAiModelInfoSaneDefaults } from "../../../../src/shared/api"
|
||||||
|
|
||||||
const AzureAiModelPicker: React.FC = () => {
|
const AzureAiModelPicker: React.FC = () => {
|
||||||
const { apiConfiguration, handleInputChange } = useExtensionState()
|
const { apiConfiguration, handleInputChange } = useExtensionState()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<div style={{ display: "flex", flexDirection: "column", rowGap: "5px" }}>
|
||||||
<VSCodeTextField
|
<VSCodeTextField
|
||||||
value={apiConfiguration?.azureAiEndpoint || ""}
|
value={apiConfiguration?.azureAiEndpoint || ""}
|
||||||
style={{ width: "100%" }}
|
style={{ width: "100%" }}
|
||||||
type="url"
|
type="url"
|
||||||
onInput={handleInputChange("azureAiEndpoint")}
|
onChange={handleInputChange("azureAiEndpoint")}
|
||||||
placeholder="https://ai-services-resource.services.ai.azure.com/models">
|
placeholder="https://your-endpoint.region.inference.ai.azure.com">
|
||||||
<span style={{ fontWeight: 500 }}>Azure AI Endpoint</span>
|
<span style={{ fontWeight: 500 }}>Base URL</span>
|
||||||
</VSCodeTextField>
|
</VSCodeTextField>
|
||||||
|
|
||||||
<VSCodeTextField
|
<VSCodeTextField
|
||||||
value={apiConfiguration?.azureAiKey || ""}
|
value={apiConfiguration?.azureAiKey || ""}
|
||||||
style={{ width: "100%" }}
|
style={{ width: "100%" }}
|
||||||
type="password"
|
type="password"
|
||||||
onInput={handleInputChange("azureAiKey")}
|
onChange={handleInputChange("azureAiKey")}
|
||||||
placeholder="Enter API Key...">
|
placeholder="Enter API Key...">
|
||||||
<span style={{ fontWeight: 500 }}>Azure AI Key</span>
|
<span style={{ fontWeight: 500 }}>API Key</span>
|
||||||
</VSCodeTextField>
|
</VSCodeTextField>
|
||||||
|
|
||||||
<VSCodeTextField
|
<VSCodeTextField
|
||||||
value={apiConfiguration?.apiModelId || ""}
|
value={apiConfiguration?.apiModelId || ""}
|
||||||
style={{ width: "100%" }}
|
style={{ width: "100%" }}
|
||||||
type="text"
|
type="text"
|
||||||
onInput={handleInputChange("apiModelId")}
|
onChange={handleInputChange("apiModelId")}
|
||||||
placeholder="Enter model deployment name...">
|
placeholder="Enter model deployment name...">
|
||||||
<span style={{ fontWeight: 500 }}>Deployment Name</span>
|
<span style={{ fontWeight: 500 }}>Model Deployment Name</span>
|
||||||
</VSCodeTextField>
|
</VSCodeTextField>
|
||||||
|
|
||||||
|
<Pane
|
||||||
|
title="Model Configuration"
|
||||||
|
open={false}
|
||||||
|
actions={[
|
||||||
|
{
|
||||||
|
iconName: "refresh",
|
||||||
|
onClick: () =>
|
||||||
|
handleInputChange("openAiCustomModelInfo")({
|
||||||
|
target: { value: azureAiModelInfoSaneDefaults },
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
]}>
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
padding: 15,
|
||||||
|
backgroundColor: "var(--vscode-editor-background)",
|
||||||
|
}}>
|
||||||
|
<p
|
||||||
|
style={{
|
||||||
|
fontSize: "12px",
|
||||||
|
color: "var(--vscode-descriptionForeground)",
|
||||||
|
margin: "0 0 15px 0",
|
||||||
|
lineHeight: "1.4",
|
||||||
|
}}>
|
||||||
|
Configure capabilities for your deployed model.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
backgroundColor: "var(--vscode-editor-inactiveSelectionBackground)",
|
||||||
|
padding: "12px",
|
||||||
|
borderRadius: "4px",
|
||||||
|
marginTop: "8px",
|
||||||
|
}}>
|
||||||
|
<span
|
||||||
|
style={{
|
||||||
|
fontSize: "11px",
|
||||||
|
fontWeight: 500,
|
||||||
|
color: "var(--vscode-editor-foreground)",
|
||||||
|
display: "block",
|
||||||
|
marginBottom: "10px",
|
||||||
|
}}>
|
||||||
|
Model Features
|
||||||
|
</span>
|
||||||
|
|
||||||
|
<div style={{ display: "flex", flexDirection: "column", gap: "8px" }}>
|
||||||
|
<VSCodeTextField
|
||||||
|
value={
|
||||||
|
apiConfiguration?.openAiCustomModelInfo?.contextWindow?.toString() ||
|
||||||
|
azureAiModelInfoSaneDefaults.contextWindow?.toString() ||
|
||||||
|
""
|
||||||
|
}
|
||||||
|
type="text"
|
||||||
|
style={{ width: "100%" }}
|
||||||
|
onChange={(e: any) => {
|
||||||
|
const parsed = parseInt(e.target.value)
|
||||||
|
handleInputChange("openAiCustomModelInfo")({
|
||||||
|
target: {
|
||||||
|
value: {
|
||||||
|
...(apiConfiguration?.openAiCustomModelInfo ||
|
||||||
|
azureAiModelInfoSaneDefaults),
|
||||||
|
contextWindow:
|
||||||
|
e.target.value === ""
|
||||||
|
? undefined
|
||||||
|
: isNaN(parsed)
|
||||||
|
? azureAiModelInfoSaneDefaults.contextWindow
|
||||||
|
: parsed,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
placeholder="e.g. 128000">
|
||||||
|
<span style={{ fontWeight: 500 }}>Context Window Size</span>
|
||||||
|
</VSCodeTextField>
|
||||||
|
<p
|
||||||
|
style={{
|
||||||
|
fontSize: "11px",
|
||||||
|
color: "var(--vscode-descriptionForeground)",
|
||||||
|
marginTop: "4px",
|
||||||
|
}}>
|
||||||
|
Total tokens the model can process in a single request.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Pane>
|
||||||
|
|
||||||
<p
|
<p
|
||||||
style={{
|
style={{
|
||||||
fontSize: "12px",
|
fontSize: "12px",
|
||||||
marginTop: "5px",
|
marginTop: "5px",
|
||||||
color: "var(--vscode-descriptionForeground)",
|
color: "var(--vscode-descriptionForeground)",
|
||||||
}}>
|
}}>
|
||||||
Configure your Azure AI Model Inference endpoint and model deployment. The API key is stored locally.
|
Configure your Azure AI Model Inference endpoint and model deployment. API keys are stored locally.
|
||||||
{!apiConfiguration?.azureAiKey && (
|
{!apiConfiguration?.azureAiKey && (
|
||||||
<VSCodeLink
|
<VSCodeLink
|
||||||
href="https://learn.microsoft.com/azure/ai-foundry/model-inference/reference/reference-model-inference-chat-completions"
|
href="https://learn.microsoft.com/azure/ai-foundry/model-inference/reference/reference-model-inference-chat-completions"
|
||||||
style={{ display: "inline", fontSize: "inherit" }}>
|
style={{ display: "inline", fontSize: "inherit" }}>
|
||||||
{" "}
|
{" "}
|
||||||
Learn more about Azure AI Model Inference endpoints.
|
Learn more about Azure AI Model Inference.
|
||||||
</VSCodeLink>
|
</VSCodeLink>
|
||||||
)}
|
)}
|
||||||
</p>
|
</p>
|
||||||
</>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user