From cfc8c08ec60215145e65519a1ea5c280d04ce929 Mon Sep 17 00:00:00 2001 From: pacnpal <183241239+pacnpal@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:32:31 -0500 Subject: [PATCH] fix: update Azure AI handler to improve error handling and support new response format --- azure-ai-inference-provider-plan.md | 337 ---------------------------- package-lock.json | 249 ++++++++++++++++++++ package.json | 4 +- src/api/providers/azure-ai.ts | 274 ++++++++++++---------- 4 files changed, 401 insertions(+), 463 deletions(-) delete mode 100644 azure-ai-inference-provider-plan.md diff --git a/azure-ai-inference-provider-plan.md b/azure-ai-inference-provider-plan.md deleted file mode 100644 index 41fbf2e..0000000 --- a/azure-ai-inference-provider-plan.md +++ /dev/null @@ -1,337 +0,0 @@ -# Azure AI Inference Provider Implementation Plan - -## Overview -This document outlines the implementation plan for adding Azure AI Inference support as a new provider in `src/api/providers/`. While Azure AI uses OpenAI's API format as a base, there are significant differences in implementation that need to be accounted for. - -## Key Differences from OpenAI - -### Endpoint Structure -- OpenAI: `https://api.openai.com/v1/chat/completions` -- Azure: `https://{resource-name}.openai.azure.com/openai/deployments/{deployment-name}/chat/completions?api-version={api-version}` - -### Authentication -- OpenAI: Uses `Authorization: Bearer sk-...` -- Azure: Uses `api-key: {key}` - -### Request Format -- OpenAI: Requires `model` field in request body -- Azure: Omits `model` from body (uses deployment name in URL instead) - -### Special Considerations -- Required API version in URL query parameter -- Model-Mesh deployments require additional header: `x-ms-model-mesh-model-name` -- Different API versions for different features (e.g., 2023-12-01-preview, 2024-02-15-preview) - -## Dependencies - -```typescript -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI, { AzureOpenAI } from "openai" -import { - ApiHandlerOptions, - ModelInfo, - azureAiDefaultModelId, - AzureAiModelId, - azureAiModels -} from "../../shared/api" -import { ApiHandler, SingleCompletionHandler } from "../index" -import { convertToOpenAiMessages } from "../transform/openai-format" -import { ApiStream } from "../transform/stream" -``` - -## Configuration (shared/api.ts) - -```typescript -export type AzureAiModelId = "gpt-35-turbo" | "gpt-4" | "gpt-4-turbo" - -export interface AzureDeploymentConfig { - name: string - apiVersion: string - modelMeshName?: string // For Model-Mesh deployments -} - -export const azureAiModels: Record = { - "gpt-35-turbo": { - maxTokens: 4096, - contextWindow: 16385, - supportsPromptCache: true, - inputPrice: 0.0015, - outputPrice: 0.002, - defaultDeployment: { - name: "gpt-35-turbo", - apiVersion: "2024-02-15-preview" - } - }, - "gpt-4": { - maxTokens: 8192, - contextWindow: 8192, - supportsPromptCache: true, - inputPrice: 0.03, - outputPrice: 0.06, - defaultDeployment: { - name: "gpt-4", - apiVersion: "2024-02-15-preview" - } - }, - "gpt-4-turbo": { - maxTokens: 4096, - contextWindow: 128000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.03, - defaultDeployment: { - name: "gpt-4-turbo", - apiVersion: "2024-02-15-preview" - } - } -} - -export const azureAiDefaultModelId: AzureAiModelId = "gpt-35-turbo" -``` - -## Implementation (src/api/providers/azure-ai.ts) - -```typescript -export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { - private options: ApiHandlerOptions - private client: AzureOpenAI - - constructor(options: ApiHandlerOptions) { - this.options = options - - if (!options.azureAiEndpoint) { - throw new Error("Azure AI endpoint is required") - } - - if (!options.azureAiKey) { - throw new Error("Azure AI key is required") - } - - const deployment = this.getDeploymentConfig() - - this.client = new AzureOpenAI({ - apiKey: options.azureAiKey, - endpoint: options.azureAiEndpoint, - deployment: deployment.name, - apiVersion: deployment.apiVersion, - headers: deployment.modelMeshName ? { - 'x-ms-model-mesh-model-name': deployment.modelMeshName - } : undefined - }) - } - - private getDeploymentConfig(): AzureDeploymentConfig { - const model = this.getModel() - const defaultConfig = azureAiModels[model.id].defaultDeployment - - // Override with user-provided deployment names if available - const deploymentName = - this.options.azureAiDeployments?.[model.id]?.name || - defaultConfig.name - - const apiVersion = - this.options.azureAiDeployments?.[model.id]?.apiVersion || - defaultConfig.apiVersion - - const modelMeshName = - this.options.azureAiDeployments?.[model.id]?.modelMeshName - - return { - name: deploymentName, - apiVersion, - modelMeshName - } - } - - async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const modelInfo = this.getModel().info - - const systemMessage = { - role: "system", - content: systemPrompt - } - - // Note: model parameter is omitted as it's handled by deployment - const requestOptions: Omit = { - messages: [systemMessage, ...convertToOpenAiMessages(messages)], - temperature: 0, - stream: true, - max_tokens: modelInfo.maxTokens - } - - try { - const stream = await this.client.chat.completions.create(requestOptions as any) - - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta - - if (delta?.content) { - yield { - type: "text", - text: delta.content - } - } - - if (chunk.usage) { - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0 - } - } - } - } catch (error) { - // Handle Azure-specific error format - if (error instanceof Error) { - const azureError = error as any - throw new Error( - `Azure AI error (${azureError.code || 'Unknown'}): ${azureError.message}` - ) - } - throw error - } - } - - getModel(): { id: AzureAiModelId; info: ModelInfo } { - const modelId = this.options.apiModelId - if (modelId && modelId in azureAiModels) { - const id = modelId as AzureAiModelId - return { id, info: azureAiModels[id] } - } - return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] } - } - - async completePrompt(prompt: string): Promise { - try { - // Note: model parameter is omitted as it's handled by deployment - const response = await this.client.chat.completions.create({ - messages: [{ role: "user", content: prompt }], - temperature: 0 - } as any) - - return response.choices[0]?.message.content || "" - } catch (error) { - // Handle Azure-specific error format - if (error instanceof Error) { - const azureError = error as any - throw new Error( - `Azure AI completion error (${azureError.code || 'Unknown'}): ${azureError.message}` - ) - } - throw error - } - } -} -``` - -## Required Updates to ApiHandlerOptions - -Add to ApiHandlerOptions interface in shared/api.ts: - -```typescript -azureAiEndpoint?: string -azureAiKey?: string -azureAiDeployments?: { - [key in AzureAiModelId]?: { - name: string - apiVersion: string - modelMeshName?: string - } -} -``` - -## Testing Plan - -1. Create __tests__ directory with azure-ai.test.ts: - ```typescript - describe('AzureAiHandler', () => { - // Test URL construction - test('constructs correct Azure endpoint URL', () => {}) - - // Test authentication - test('sets correct authentication headers', () => {}) - - // Test deployment configuration - test('uses correct deployment names', () => {}) - test('handles Model-Mesh configuration', () => {}) - - // Test error handling - test('handles Azure-specific error format', () => {}) - - // Test request/response format - test('omits model from request body', () => {}) - test('handles Azure response format', () => {}) - }) - ``` - -## Integration Steps - -1. Add Azure AI models and types to shared/api.ts -2. Create azure-ai.ts provider implementation -3. Add provider tests -4. Update API handler options -5. Add deployment configuration support -6. Implement Azure-specific error handling -7. Test with real Azure AI endpoints - -## Error Handling - -Azure returns errors in a specific format: -```typescript -interface AzureError { - code: string // e.g., "InternalServerError", "InvalidRequest" - message: string - target?: string - details?: Array<{ - code: string - message: string - }> -} -``` - -Implementation should: -- Parse Azure error format -- Include error codes in messages -- Handle deployment-specific errors -- Provide clear upgrade paths for API version issues - -## Documentation Updates - -1. Add Azure AI configuration section to README.md: - - Endpoint configuration - - Authentication setup - - Deployment mapping - - API version selection - - Model-Mesh support - -2. Document configuration examples: - ```typescript - { - azureAiEndpoint: "https://your-resource.openai.azure.com", - azureAiKey: "your-api-key", - azureAiDeployments: { - "gpt-4": { - name: "your-gpt4-deployment", - apiVersion: "2024-02-15-preview", - modelMeshName: "optional-model-mesh-name" - } - } - } - ``` - -## Future Improvements - -1. Support for Azure-specific features: - - Fine-tuning endpoints - - Custom deployment configurations - - Managed identity authentication - -2. Performance optimizations: - - Connection pooling - - Regional endpoint selection - - Automatic API version negotiation - -3. Advanced features: - - Response format control - - Function calling support - - Vision model support \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index eecc0e6..86f8c84 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,8 @@ "@anthropic-ai/sdk": "^0.26.0", "@anthropic-ai/vertex-sdk": "^0.4.1", "@aws-sdk/client-bedrock-runtime": "^3.706.0", + "@azure-rest/ai-inference": "^1.0.0-beta.5", + "@azure/core-auth": "^1.5.0", "@google/generative-ai": "^0.18.0", "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.0.1", @@ -2222,6 +2224,253 @@ "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==" }, + "node_modules/@azure-rest/ai-inference": { + "version": "1.0.0-beta.5", + "resolved": "https://registry.npmjs.org/@azure-rest/ai-inference/-/ai-inference-1.0.0-beta.5.tgz", + "integrity": "sha512-G6tAWR7DGHTfWx5+N5csTWX304lWNWeePXHx1LBYKLhTeonNTY4OrpqC6DD12oPxLuK0WbEJ3JXK/A3HdKj+BA==", + "license": "MIT", + "dependencies": { + "@azure-rest/core-client": "^2.1.0", + "@azure/abort-controller": "^1.0.0", + "@azure/core-auth": "^1.7.2", + "@azure/core-lro": "^2.6.0", + "@azure/core-rest-pipeline": "^1.14.0", + "@azure/core-tracing": "^1.2.0", + "@azure/logger": "^1.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure-rest/ai-inference/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure-rest/core-client": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/@azure-rest/core-client/-/core-client-2.3.2.tgz", + "integrity": "sha512-rS8Z6iNCaGYQZz96SdUpRw75j3b5vRpEJqocSJwnuByrydirubjUkY54pThm7GshRBgh7GdMK4hGOZA6BSeRaw==", + "license": "MIT", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-auth": "^1.3.0", + "@azure/core-rest-pipeline": "^1.5.0", + "@azure/core-tracing": "^1.0.1", + "@azure/core-util": "^1.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure-rest/core-client/node_modules/@azure/abort-controller": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-2.1.2.tgz", + "integrity": "sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure-rest/core-client/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/abort-controller": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-1.1.0.tgz", + "integrity": "sha512-TrRLIoSQVzfAJX9H1JeFjzAoDGcoK1IYX1UImfceTZpsyYfWr09Ss1aHW1y5TrrR3iq6RZLBwJ3E24uwPhwahw==", + "license": "MIT", + "dependencies": { + "tslib": "^2.2.0" + }, + "engines": { + "node": ">=12.0.0" + } + }, + "node_modules/@azure/abort-controller/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/core-auth": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@azure/core-auth/-/core-auth-1.9.0.tgz", + "integrity": "sha512-FPwHpZywuyasDSLMqJ6fhbOK3TqUdviZNF8OqRGA4W5Ewib2lEEZ+pBsYcBa88B2NGO/SEnYPGhyBqNlE8ilSw==", + "license": "MIT", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-util": "^1.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-auth/node_modules/@azure/abort-controller": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-2.1.2.tgz", + "integrity": "sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-auth/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/core-lro": { + "version": "2.7.2", + "resolved": "https://registry.npmjs.org/@azure/core-lro/-/core-lro-2.7.2.tgz", + "integrity": "sha512-0YIpccoX8m/k00O7mDDMdJpbr6mf1yWo2dfmxt5A8XVZVVMz2SSKaEbMCeJRvgQ0IaSlqhjT47p4hVIRRy90xw==", + "license": "MIT", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-util": "^1.2.0", + "@azure/logger": "^1.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-lro/node_modules/@azure/abort-controller": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-2.1.2.tgz", + "integrity": "sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-lro/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/core-rest-pipeline": { + "version": "1.18.2", + "resolved": "https://registry.npmjs.org/@azure/core-rest-pipeline/-/core-rest-pipeline-1.18.2.tgz", + "integrity": "sha512-IkTf/DWKyCklEtN/WYW3lqEsIaUDshlzWRlZNNwSYtFcCBQz++OtOjxNpm8rr1VcbMS6RpjybQa3u6B6nG0zNw==", + "license": "MIT", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-auth": "^1.8.0", + "@azure/core-tracing": "^1.0.1", + "@azure/core-util": "^1.11.0", + "@azure/logger": "^1.0.0", + "http-proxy-agent": "^7.0.0", + "https-proxy-agent": "^7.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-rest-pipeline/node_modules/@azure/abort-controller": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-2.1.2.tgz", + "integrity": "sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-rest-pipeline/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/core-tracing": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@azure/core-tracing/-/core-tracing-1.2.0.tgz", + "integrity": "sha512-UKTiEJPkWcESPYJz3X5uKRYyOcJD+4nYph+KpfdPRnQJVrZfk0KJgdnaAWKfhsBBtAf/D58Az4AvCJEmWgIBAg==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-tracing/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/core-util": { + "version": "1.11.0", + "resolved": "https://registry.npmjs.org/@azure/core-util/-/core-util-1.11.0.tgz", + "integrity": "sha512-DxOSLua+NdpWoSqULhjDyAZTXFdP/LKkqtYuxxz1SCN289zk3OG8UOpnCQAz/tygyACBtWp/BoO72ptK7msY8g==", + "license": "MIT", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-util/node_modules/@azure/abort-controller": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-2.1.2.tgz", + "integrity": "sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-util/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/@azure/logger": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/@azure/logger/-/logger-1.1.4.tgz", + "integrity": "sha512-4IXXzcCdLdlXuCG+8UKEwLA1T1NHqUfanhXYHiQTn+6sfWCZXduqbtXDGceg3Ce5QxTGo7EqmbV6Bi+aqKuClQ==", + "license": "MIT", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/logger/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, "node_modules/@babel/code-frame": { "version": "7.26.2", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", diff --git a/package.json b/package.json index ec5147a..5ba27e6 100644 --- a/package.json +++ b/package.json @@ -271,8 +271,8 @@ "@anthropic-ai/bedrock-sdk": "^0.10.2", "@anthropic-ai/sdk": "^0.26.0", "@anthropic-ai/vertex-sdk": "^0.4.1", - "@azure-rest/ai-inference": "^1.0.0", - "@azure/core-auth": "^1.5.0", + "@azure-rest/ai-inference": "^1.0.0-beta.5", + "@azure/core-auth": "^1.5.0", "@aws-sdk/client-bedrock-runtime": "^3.706.0", "@google/generative-ai": "^0.18.0", "@mistralai/mistralai": "^1.3.6", diff --git a/src/api/providers/azure-ai.ts b/src/api/providers/azure-ai.ts index 18feee9..322c7cf 100644 --- a/src/api/providers/azure-ai.ts +++ b/src/api/providers/azure-ai.ts @@ -3,145 +3,171 @@ import ModelClient from "@azure-rest/ai-inference" import { isUnexpected } from "@azure-rest/ai-inference" import { AzureKeyCredential } from "@azure/core-auth" import { - ApiHandlerOptions, - ModelInfo, - azureAiDefaultModelId, - AzureAiModelId, - azureAiModels, - AzureDeploymentConfig + ApiHandlerOptions, + ModelInfo, + azureAiDefaultModelId, + AzureAiModelId, + azureAiModels, + AzureDeploymentConfig, } from "../../shared/api" import { ApiHandler, SingleCompletionHandler } from "../index" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" +import { createSseStream } from "@azure/core-rest-pipeline" export class AzureAiHandler implements ApiHandler, SingleCompletionHandler { - private options: ApiHandlerOptions - private client: ModelClient - - constructor(options: ApiHandlerOptions) { - this.options = options - - if (!options.azureAiEndpoint) { - throw new Error("Azure AI endpoint is required") - } - - if (!options.azureAiKey) { - throw new Error("Azure AI key is required") - } + private options: ApiHandlerOptions + private client: ModelClient - this.client = new ModelClient( - options.azureAiEndpoint, - new AzureKeyCredential(options.azureAiKey) - ) - } - - private getDeploymentConfig(): AzureDeploymentConfig { - const model = this.getModel() - const defaultConfig = azureAiModels[model.id].defaultDeployment - - return { - name: this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name, - apiVersion: this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion, - modelMeshName: this.options.azureAiDeployments?.[model.id]?.modelMeshName - } - } + constructor(options: ApiHandlerOptions) { + this.options = options - async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const modelInfo = this.getModel().info - const chatMessages = [ - { role: "system", content: systemPrompt }, - ...convertToOpenAiMessages(messages) - ] + if (!options.azureAiEndpoint) { + throw new Error("Azure AI endpoint is required") + } - try { - const response = await this.client.path("/chat/completions").post({ - body: { - messages: chatMessages, - temperature: 0, - stream: true, - max_tokens: modelInfo.maxTokens - } - }).asNodeStream() + if (!options.azureAiKey) { + throw new Error("Azure AI key is required") + } - const stream = response.body - if (!stream) { - throw new Error(`Failed to get chat completions with status: ${response.status}`) - } + this.client = new ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey)) + } - if (response.status !== 200) { - throw new Error(`Failed to get chat completions: ${response.body.error}`) - } + private getDeploymentConfig(): AzureDeploymentConfig { + const model = this.getModel() + const defaultConfig = azureAiModels[model.id].defaultDeployment - for await (const chunk of stream) { - if (chunk.toString() === 'data: [DONE]') { - return - } + return { + name: this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name, + apiVersion: this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion, + modelMeshName: this.options.azureAiDeployments?.[model.id]?.modelMeshName, + } + } - try { - const data = JSON.parse(chunk.toString().replace('data: ', '')) - const delta = data.choices[0]?.delta - - if (delta?.content) { - yield { - type: "text", - text: delta.content - } - } + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const modelInfo = this.getModel().info + const chatMessages = [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)] - if (data.usage) { - yield { - type: "usage", - inputTokens: data.usage.prompt_tokens || 0, - outputTokens: data.usage.completion_tokens || 0 - } - } - } catch (e) { - // Ignore parse errors from incomplete chunks - continue - } - } - } catch (error) { - if (error instanceof Error) { - if ('status' in error && error.status === 429) { - throw new Error("Azure AI rate limit exceeded. Please try again later.") - } - throw new Error(`Azure AI error: ${error.message}`) - } - throw error - } - } + try { + const response = await this.client + .path("/chat/completions") + .post({ + body: { + messages: chatMessages, + temperature: 0, + stream: true, + max_tokens: modelInfo.maxTokens, + response_format: { type: "text" }, // Ensure text format for chat + }, + headers: this.getDeploymentConfig().modelMeshName + ? { + "x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName, + } + : undefined, + }) + .asNodeStream() - getModel(): { id: AzureAiModelId; info: ModelInfo } { - const modelId = this.options.apiModelId - if (modelId && modelId in azureAiModels) { - const id = modelId as AzureAiModelId - return { id, info: azureAiModels[id] } - } - return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] } - } + const stream = response.body + if (!stream) { + throw new Error(`Failed to get chat completions with status: ${response.status}`) + } - async completePrompt(prompt: string): Promise { - try { - const response = await this.client.path("/chat/completions").post({ - body: { - messages: [{ role: "user", content: prompt }], - temperature: 0 - } - }) + if (response.status !== 200) { + throw new Error(`Failed to get chat completions: ${response.body.error}`) + } - if (isUnexpected(response)) { - throw response.body.error - } + const sseStream = createSseStream(stream) - return response.body.choices[0]?.message?.content || "" - } catch (error) { - if (error instanceof Error) { - if ('status' in error && error.status === 429) { - throw new Error("Azure AI rate limit exceeded. Please try again later.") - } - throw new Error(`Azure AI completion error: ${error.message}`) - } - throw error - } - } -} \ No newline at end of file + for await (const event of sseStream) { + if (event.data === "[DONE]") { + return + } + + try { + const data = JSON.parse(event.data) + const delta = data.choices[0]?.delta + + if (delta?.content) { + yield { + type: "text", + text: delta.content, + } + } + + if (data.usage) { + yield { + type: "usage", + inputTokens: data.usage.prompt_tokens || 0, + outputTokens: data.usage.completion_tokens || 0, + } + } + } catch (e) { + // Ignore parse errors from incomplete chunks + continue + } + } + } catch (error) { + if (error instanceof Error) { + // Handle Azure-specific error cases + if ("status" in error && error.status === 429) { + throw new Error("Azure AI rate limit exceeded. Please try again later.") + } + if ("status" in error && error.status === 400) { + const azureError = error as any + if (azureError.body?.error?.code === "ContentFilterError") { + throw new Error("Content was flagged by Azure AI content safety filters") + } + } + throw new Error(`Azure AI error: ${error.message}`) + } + throw error + } + } + + getModel(): { id: AzureAiModelId; info: ModelInfo } { + const modelId = this.options.apiModelId + if (modelId && modelId in azureAiModels) { + const id = modelId as AzureAiModelId + return { id, info: azureAiModels[id] } + } + return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] } + } + + async completePrompt(prompt: string): Promise { + try { + const response = await this.client.path("/chat/completions").post({ + body: { + messages: [{ role: "user", content: prompt }], + temperature: 0, + response_format: { type: "text" }, + }, + headers: this.getDeploymentConfig().modelMeshName + ? { + "x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName, + } + : undefined, + }) + + if (isUnexpected(response)) { + throw response.body.error + } + + return response.body.choices[0]?.message?.content || "" + } catch (error) { + if (error instanceof Error) { + // Handle Azure-specific error cases + if ("status" in error && error.status === 429) { + throw new Error("Azure AI rate limit exceeded. Please try again later.") + } + if ("status" in error && error.status === 400) { + const azureError = error as any + if (azureError.body?.error?.code === "ContentFilterError") { + throw new Error("Content was flagged by Azure AI content safety filters") + } + } + throw new Error(`Azure AI completion error: ${error.message}`) + } + throw error + } + } +}