mirror of https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00

feat: add Azure AI integration with deployment configuration

azure-ai-inference-provider-plan.md (new file, 337 lines)
@@ -0,0 +1,337 @@
# Azure AI Inference Provider Implementation Plan

## Overview

This document outlines the implementation plan for adding Azure AI Inference support as a new provider in `src/api/providers/`. Although Azure AI uses OpenAI's API format as a base, there are significant implementation differences that need to be accounted for.

## Key Differences from OpenAI

### Endpoint Structure

- OpenAI: `https://api.openai.com/v1/chat/completions`
- Azure: `https://{resource-name}.openai.azure.com/openai/deployments/{deployment-name}/chat/completions?api-version={api-version}`

### Authentication

- OpenAI: uses an `Authorization: Bearer sk-...` header
- Azure: uses an `api-key: {key}` header

### Request Format

- OpenAI: requires a `model` field in the request body
- Azure: omits `model` from the body (the deployment name in the URL selects the model)

### Special Considerations

- The API version is required as a URL query parameter
- Model-Mesh deployments require an additional header: `x-ms-model-mesh-model-name`
- Different features require different API versions (e.g., 2023-12-01-preview, 2024-02-15-preview)
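To make the endpoint, header, and body differences concrete, here is a minimal sketch of a raw Azure call using `fetch`. The resource name, deployment name, and environment variable are placeholder assumptions, not values from this plan:

```typescript
// Sketch only: resource/deployment names and the env var are placeholders.
const resource = "your-resource"
const deployment = "your-gpt4-deployment"
const apiVersion = "2024-02-15-preview"

const url =
	`https://${resource}.openai.azure.com/openai/deployments/${deployment}` +
	`/chat/completions?api-version=${apiVersion}`

const response = await fetch(url, {
	method: "POST",
	headers: {
		"Content-Type": "application/json",
		"api-key": process.env.AZURE_AI_KEY ?? "", // not an Authorization: Bearer header
		// "x-ms-model-mesh-model-name": "your-mesh-model", // only for Model-Mesh deployments
	},
	// No `model` field: the deployment name in the URL selects the model.
	body: JSON.stringify({
		messages: [{ role: "user", content: "Hello" }],
		max_tokens: 64,
	}),
})
console.log((await response.json()).choices[0].message.content)
```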
## Dependencies

```typescript
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI, { AzureOpenAI } from "openai"
import {
	ApiHandlerOptions,
	ModelInfo,
	azureAiDefaultModelId,
	AzureAiModelId,
	azureAiModels,
} from "../../shared/api"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
```
## Configuration (shared/api.ts)

```typescript
export type AzureAiModelId = "gpt-35-turbo" | "gpt-4" | "gpt-4-turbo"

export interface AzureDeploymentConfig {
	name: string
	apiVersion: string
	modelMeshName?: string // For Model-Mesh deployments
}

export const azureAiModels: Record<AzureAiModelId, ModelInfo & { defaultDeployment: AzureDeploymentConfig }> = {
	"gpt-35-turbo": {
		maxTokens: 4096,
		contextWindow: 16385,
		supportsPromptCache: true,
		inputPrice: 0.0015,
		outputPrice: 0.002,
		defaultDeployment: {
			name: "gpt-35-turbo",
			apiVersion: "2024-02-15-preview",
		},
	},
	"gpt-4": {
		maxTokens: 8192,
		contextWindow: 8192,
		supportsPromptCache: true,
		inputPrice: 0.03,
		outputPrice: 0.06,
		defaultDeployment: {
			name: "gpt-4",
			apiVersion: "2024-02-15-preview",
		},
	},
	"gpt-4-turbo": {
		maxTokens: 4096,
		contextWindow: 128000,
		supportsPromptCache: true,
		inputPrice: 0.01,
		outputPrice: 0.03,
		defaultDeployment: {
			name: "gpt-4-turbo",
			apiVersion: "2024-02-15-preview",
		},
	},
}

export const azureAiDefaultModelId: AzureAiModelId = "gpt-35-turbo"
```
## Implementation (src/api/providers/azure-ai.ts)

```typescript
export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
	private options: ApiHandlerOptions
	private client: AzureOpenAI

	constructor(options: ApiHandlerOptions) {
		this.options = options

		if (!options.azureAiEndpoint) {
			throw new Error("Azure AI endpoint is required")
		}

		if (!options.azureAiKey) {
			throw new Error("Azure AI key is required")
		}

		const deployment = this.getDeploymentConfig()

		this.client = new AzureOpenAI({
			apiKey: options.azureAiKey,
			endpoint: options.azureAiEndpoint,
			deployment: deployment.name,
			apiVersion: deployment.apiVersion,
			// The OpenAI SDK takes extra headers via `defaultHeaders`
			defaultHeaders: deployment.modelMeshName
				? { "x-ms-model-mesh-model-name": deployment.modelMeshName }
				: undefined,
		})
	}

	private getDeploymentConfig(): AzureDeploymentConfig {
		const model = this.getModel()
		const defaultConfig = azureAiModels[model.id].defaultDeployment

		// Override with user-provided deployment names if available
		const deploymentName = this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name
		const apiVersion = this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion
		const modelMeshName = this.options.azureAiDeployments?.[model.id]?.modelMeshName

		return {
			name: deploymentName,
			apiVersion,
			modelMeshName,
		}
	}

	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const modelInfo = this.getModel().info

		const systemMessage = {
			role: "system" as const,
			content: systemPrompt,
		}

		// Note: the model parameter is omitted because it is handled by the deployment
		const requestOptions: Omit<OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming, "model"> = {
			messages: [systemMessage, ...convertToOpenAiMessages(messages)],
			temperature: 0,
			stream: true,
			max_tokens: modelInfo.maxTokens,
		}

		try {
			const stream = await this.client.chat.completions.create(requestOptions as any)

			for await (const chunk of stream) {
				const delta = chunk.choices[0]?.delta

				if (delta?.content) {
					yield {
						type: "text",
						text: delta.content,
					}
				}

				if (chunk.usage) {
					yield {
						type: "usage",
						inputTokens: chunk.usage.prompt_tokens || 0,
						outputTokens: chunk.usage.completion_tokens || 0,
					}
				}
			}
		} catch (error) {
			// Handle Azure-specific error format
			if (error instanceof Error) {
				const azureError = error as any
				throw new Error(`Azure AI error (${azureError.code || "Unknown"}): ${azureError.message}`)
			}
			throw error
		}
	}

	getModel(): { id: AzureAiModelId; info: ModelInfo } {
		const modelId = this.options.apiModelId
		if (modelId && modelId in azureAiModels) {
			const id = modelId as AzureAiModelId
			return { id, info: azureAiModels[id] }
		}
		return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] }
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			// Note: the model parameter is omitted because it is handled by the deployment
			const response = await this.client.chat.completions.create({
				messages: [{ role: "user", content: prompt }],
				temperature: 0,
			} as any)

			return response.choices[0]?.message.content || ""
		} catch (error) {
			// Handle Azure-specific error format
			if (error instanceof Error) {
				const azureError = error as any
				throw new Error(`Azure AI completion error (${azureError.code || "Unknown"}): ${azureError.message}`)
			}
			throw error
		}
	}
}
```
## Required Updates to ApiHandlerOptions

Add to the `ApiHandlerOptions` interface in shared/api.ts:

```typescript
azureAiEndpoint?: string
azureAiKey?: string
azureAiDeployments?: {
	[key in AzureAiModelId]?: {
		name: string
		apiVersion: string
		modelMeshName?: string
	}
}
```
## Testing Plan

1. Create a `__tests__` directory with `azure-ai.test.ts`:

```typescript
describe("AzureAiHandler", () => {
	// Test URL construction
	test("constructs correct Azure endpoint URL", () => {})

	// Test authentication
	test("sets correct authentication headers", () => {})

	// Test deployment configuration
	test("uses correct deployment names", () => {})
	test("handles Model-Mesh configuration", () => {})

	// Test error handling
	test("handles Azure-specific error format", () => {})

	// Test request/response format
	test("omits model from request body", () => {})
	test("handles Azure response format", () => {})
})
```
## Integration Steps

1. Add Azure AI models and types to shared/api.ts
2. Create the azure-ai.ts provider implementation
3. Add provider tests
4. Update API handler options
5. Add deployment configuration support
6. Implement Azure-specific error handling
7. Test with real Azure AI endpoints
## Error Handling

Azure returns errors in a specific format:

```typescript
interface AzureError {
	code: string // e.g., "InternalServerError", "InvalidRequest"
	message: string
	target?: string
	details?: Array<{
		code: string
		message: string
	}>
}
```

The implementation should:

- Parse the Azure error format
- Include error codes in surfaced messages
- Handle deployment-specific errors
- Provide clear upgrade paths for API version issues (see the sketch below)
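A small helper could centralize this mapping. The following is a sketch only; the `formatAzureError` name and the version-hint heuristic are illustrative assumptions, not part of this plan:

```typescript
// Sketch: map the AzureError shape above to a user-facing Error.
// formatAzureError and the version heuristic are illustrative assumptions.
function formatAzureError(err: AzureError): Error {
	const details = err.details?.map((d) => `${d.code}: ${d.message}`).join("; ")
	let message = `Azure AI error (${err.code}): ${err.message}`
	if (details) {
		message += ` [${details}]`
	}
	// Point at an upgrade path when the API version looks like the culprit
	if (/api[- ]?version/i.test(err.message)) {
		message += " (try a newer api-version, e.g. 2024-02-15-preview)"
	}
	return new Error(message)
}
```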
## Documentation Updates

1. Add an Azure AI configuration section to README.md:
   - Endpoint configuration
   - Authentication setup
   - Deployment mapping
   - API version selection
   - Model-Mesh support

2. Document configuration examples:

```typescript
{
	azureAiEndpoint: "https://your-resource.openai.azure.com",
	azureAiKey: "your-api-key",
	azureAiDeployments: {
		"gpt-4": {
			name: "your-gpt4-deployment",
			apiVersion: "2024-02-15-preview",
			modelMeshName: "optional-model-mesh-name",
		},
	},
}
```
## Future Improvements

1. Support for Azure-specific features:
   - Fine-tuning endpoints
   - Custom deployment configurations
   - Managed identity authentication

2. Performance optimizations:
   - Connection pooling
   - Regional endpoint selection
   - Automatic API version negotiation

3. Advanced features:
   - Response format control
   - Function calling support
   - Vision model support
package.json
@@ -271,6 +271,8 @@
		"@anthropic-ai/bedrock-sdk": "^0.10.2",
		"@anthropic-ai/sdk": "^0.26.0",
		"@anthropic-ai/vertex-sdk": "^0.4.1",
+		"@azure-rest/ai-inference": "^1.0.0",
+		"@azure/core-auth": "^1.5.0",
		"@aws-sdk/client-bedrock-runtime": "^3.706.0",
		"@google/generative-ai": "^0.18.0",
		"@mistralai/mistralai": "^1.3.6",
src/api/index.ts
@@ -14,6 +14,7 @@ import { DeepSeekHandler } from "./providers/deepseek"
import { MistralHandler } from "./providers/mistral"
import { VsCodeLmHandler } from "./providers/vscode-lm"
import { ApiStream } from "./transform/stream"
+import { AzureAiHandler } from "./providers/azure-ai"
import { UnboundHandler } from "./providers/unbound"

export interface SingleCompletionHandler {
@@ -56,7 +57,9 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
			return new MistralHandler(options)
		case "unbound":
			return new UnboundHandler(options)
+		case "azure-ai":
+			return new AzureAiHandler(options)
		default:
			return new AnthropicHandler(options)
	}
}
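With this dispatch in place, a configuration whose `apiProvider` is `"azure-ai"` resolves to the new handler. A minimal usage sketch (the import paths and option values are placeholder assumptions, not from this commit):

```typescript
// Sketch: routing a configuration to AzureAiHandler via buildApiHandler.
// Endpoint, key, and import paths are placeholders, not from this commit.
import { buildApiHandler } from "../src/api"
import { AzureAiHandler } from "../src/api/providers/azure-ai"

const handler = buildApiHandler({
	apiProvider: "azure-ai",
	apiModelId: "azure-gpt-35",
	azureAiEndpoint: "https://your-resource.inference.azure.com",
	azureAiKey: process.env.AZURE_AI_KEY ?? "",
}) as AzureAiHandler

const text = await handler.completePrompt("Say hello")
```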
src/api/providers/__tests__/azure-ai.test.ts (new file, 171 lines)
@@ -0,0 +1,171 @@
import { AzureAiHandler } from "../azure-ai"
import { ApiHandlerOptions } from "../../../shared/api"
import { Readable } from "stream"
import ModelClient from "@azure-rest/ai-inference"

// Mock the Azure AI client
jest.mock("@azure-rest/ai-inference", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			path: jest.fn().mockReturnValue({
				post: jest.fn(),
			}),
		})),
		isUnexpected: jest.fn(),
	}
})

describe("AzureAiHandler", () => {
	const mockOptions: ApiHandlerOptions = {
		apiProvider: "azure-ai",
		apiModelId: "azure-gpt-35",
		azureAiEndpoint: "https://test-resource.inference.azure.com",
		azureAiKey: "test-key",
		azureAiDeployments: {
			"azure-gpt-35": {
				name: "custom-gpt35",
				apiVersion: "2024-02-15-preview",
				modelMeshName: "test-mesh-model",
			},
		},
	}

	beforeEach(() => {
		jest.clearAllMocks()
	})

	test("constructs with required options", () => {
		const handler = new AzureAiHandler(mockOptions)
		expect(handler).toBeInstanceOf(AzureAiHandler)
	})

	test("throws error without endpoint", () => {
		const invalidOptions = { ...mockOptions }
		delete invalidOptions.azureAiEndpoint
		expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI endpoint is required")
	})

	test("throws error without API key", () => {
		const invalidOptions = { ...mockOptions }
		delete invalidOptions.azureAiKey
		expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI key is required")
	})

	test("creates chat completion correctly", async () => {
		const handler = new AzureAiHandler(mockOptions)
		const mockResponse = {
			body: {
				choices: [
					{
						message: {
							content: "test response",
						},
					},
				],
			},
		}

		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
		mockClient.prototype.path.mockReturnValue({
			post: jest.fn().mockResolvedValue(mockResponse),
		})

		const result = await handler.completePrompt("test prompt")
		expect(result).toBe("test response")

		expect(mockClient.prototype.path).toHaveBeenCalledWith("/chat/completions")
		expect(mockClient.prototype.path().post).toHaveBeenCalledWith({
			body: {
				messages: [{ role: "user", content: "test prompt" }],
				temperature: 0,
			},
		})
	})

	test("handles streaming responses correctly", async () => {
		const handler = new AzureAiHandler(mockOptions)
		const mockStream = Readable.from([
			'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
			'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
			"data: [DONE]\n\n",
		])

		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
		mockClient.prototype.path.mockReturnValue({
			post: jest.fn().mockResolvedValue({
				status: 200,
				body: mockStream,
			}),
		})

		const messages: any[] = []
		for await (const message of handler.createMessage("system prompt", [])) {
			messages.push(message)
		}

		expect(messages).toEqual([
			{ type: "text", text: "Hello" },
			{ type: "text", text: " world" },
			{ type: "usage", inputTokens: 10, outputTokens: 2 },
		])

		expect(mockClient.prototype.path().post).toHaveBeenCalledWith({
			body: {
				messages: [{ role: "system", content: "system prompt" }],
				temperature: 0,
				stream: true,
				max_tokens: expect.any(Number),
			},
		})
	})

	test("handles rate limit errors", async () => {
		const handler = new AzureAiHandler(mockOptions)
		const mockError = new Error("Rate limit exceeded")
		Object.assign(mockError, { status: 429 })

		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
		mockClient.prototype.path.mockReturnValue({
			post: jest.fn().mockRejectedValue(mockError),
		})

		await expect(handler.completePrompt("test")).rejects.toThrow(
			"Azure AI rate limit exceeded. Please try again later.",
		)
	})

	test("handles content safety errors", async () => {
		const handler = new AzureAiHandler(mockOptions)
		const mockError = {
			status: 400,
			body: {
				error: {
					code: "ContentFilterError",
					message: "Content was flagged by content safety filters",
				},
			},
		}

		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
		mockClient.prototype.path.mockReturnValue({
			post: jest.fn().mockRejectedValue(mockError),
		})

		await expect(handler.completePrompt("test")).rejects.toThrow(
			"Azure AI completion error: Content was flagged by content safety filters",
		)
	})

	test("falls back to default model configuration", async () => {
		const options = { ...mockOptions }
		delete options.azureAiDeployments

		const handler = new AzureAiHandler(options)
		const model = handler.getModel()

		expect(model.id).toBe("azure-gpt-35")
		expect(model.info).toBeDefined()
		expect(model.info.defaultDeployment.name).toBe("azure-gpt-35")
	})
})
src/api/providers/azure-ai.ts (new file, 147 lines)
@@ -0,0 +1,147 @@
import { Anthropic } from "@anthropic-ai/sdk"
import ModelClient from "@azure-rest/ai-inference"
import { isUnexpected } from "@azure-rest/ai-inference"
import { AzureKeyCredential } from "@azure/core-auth"
import {
	ApiHandlerOptions,
	ModelInfo,
	azureAiDefaultModelId,
	AzureAiModelId,
	azureAiModels,
	AzureDeploymentConfig,
} from "../../shared/api"
import { ApiHandler, SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
	private options: ApiHandlerOptions
	private client: ReturnType<typeof ModelClient>

	constructor(options: ApiHandlerOptions) {
		this.options = options

		if (!options.azureAiEndpoint) {
			throw new Error("Azure AI endpoint is required")
		}

		if (!options.azureAiKey) {
			throw new Error("Azure AI key is required")
		}

		// @azure-rest clients are created through a factory function, not `new`
		this.client = ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey))
	}

	private getDeploymentConfig(): AzureDeploymentConfig {
		const model = this.getModel()
		const defaultConfig = azureAiModels[model.id].defaultDeployment

		return {
			name: this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name,
			apiVersion: this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion,
			modelMeshName: this.options.azureAiDeployments?.[model.id]?.modelMeshName,
		}
	}

	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const modelInfo = this.getModel().info
		const chatMessages = [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)]

		try {
			const response = await this.client
				.path("/chat/completions")
				.post({
					body: {
						messages: chatMessages,
						temperature: 0,
						stream: true,
						max_tokens: modelInfo.maxTokens,
					},
				})
				.asNodeStream()

			const stream = response.body
			if (!stream) {
				throw new Error(`Failed to get chat completions with status: ${response.status}`)
			}

			if (response.status !== 200) {
				throw new Error(`Failed to get chat completions with status: ${response.status}`)
			}

			// Each SSE event arrives as a `data: {...}` JSON payload
			for await (const chunk of stream) {
				if (chunk.toString() === "data: [DONE]") {
					return
				}

				try {
					const data = JSON.parse(chunk.toString().replace("data: ", ""))
					const delta = data.choices[0]?.delta

					if (delta?.content) {
						yield {
							type: "text",
							text: delta.content,
						}
					}

					if (data.usage) {
						yield {
							type: "usage",
							inputTokens: data.usage.prompt_tokens || 0,
							outputTokens: data.usage.completion_tokens || 0,
						}
					}
				} catch (e) {
					// Ignore parse errors from incomplete chunks
					continue
				}
			}
		} catch (error) {
			if (error instanceof Error) {
				if ("status" in error && error.status === 429) {
					throw new Error("Azure AI rate limit exceeded. Please try again later.")
				}
				throw new Error(`Azure AI error: ${error.message}`)
			}
			throw error
		}
	}

	getModel(): { id: AzureAiModelId; info: ModelInfo } {
		const modelId = this.options.apiModelId
		if (modelId && modelId in azureAiModels) {
			const id = modelId as AzureAiModelId
			return { id, info: azureAiModels[id] }
		}
		return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] }
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const response = await this.client.path("/chat/completions").post({
				body: {
					messages: [{ role: "user", content: prompt }],
					temperature: 0,
				},
			})

			if (isUnexpected(response)) {
				throw response.body.error
			}

			return response.body.choices[0]?.message?.content || ""
		} catch (error) {
			if (error instanceof Error) {
				if ("status" in error && error.status === 429) {
					throw new Error("Azure AI rate limit exceeded. Please try again later.")
				}
				throw new Error(`Azure AI completion error: ${error.message}`)
			}
			throw error
		}
	}
}
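One caveat on the streaming loop above: the chunk-by-chunk `JSON.parse` assumes every stream chunk is exactly one complete `data:` line, which Node streams do not guarantee. A buffered SSE parser along these lines would be more robust (a sketch under that assumption; `parseSseEvents` is a hypothetical helper, not part of this commit):

```typescript
// Sketch: buffer incoming chunks and only parse complete SSE events.
// parseSseEvents is a hypothetical helper, not code from this commit.
async function* parseSseEvents(stream: AsyncIterable<Buffer>): AsyncGenerator<unknown> {
	let buffer = ""
	for await (const chunk of stream) {
		buffer += chunk.toString()
		let idx: number
		// SSE events are separated by a blank line ("\n\n")
		while ((idx = buffer.indexOf("\n\n")) !== -1) {
			const event = buffer.slice(0, idx).trim()
			buffer = buffer.slice(idx + 2)
			if (!event.startsWith("data: ")) continue
			const payload = event.slice("data: ".length)
			if (payload === "[DONE]") return
			yield JSON.parse(payload) // complete event, safe to parse
		}
	}
}
```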
ClineProvider.ts
@@ -86,6 +86,7 @@ type GlobalStateKey =
	| "lmStudioBaseUrl"
	| "anthropicBaseUrl"
	| "azureApiVersion"
+	| "azureAiDeployments"
	| "openAiStreamingEnabled"
	| "openRouterModelId"
	| "openRouterModelInfo"
@@ -1074,6 +1075,16 @@ export class ClineProvider implements vscode.WebviewViewProvider {
					await this.updateGlobalState("autoApprovalEnabled", message.bool ?? false)
					await this.postStateToWebview()
					break
+				case "updateAzureAiDeployment":
+					if (message.azureAiDeployment) {
+						const deployments = (await this.getGlobalState("azureAiDeployments")) || {}
+						deployments[message.azureAiDeployment.modelId] = {
+							...message.azureAiDeployment,
+						}
+						await this.updateGlobalState("azureAiDeployments", deployments)
+						await this.postStateToWebview()
+					}
+					break
				case "enhancePrompt":
					if (message.text) {
						try {
@@ -1506,6 +1517,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
		await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
		await this.storeSecret("deepSeekApiKey", deepSeekApiKey)
		await this.updateGlobalState("azureApiVersion", azureApiVersion)
+		await this.updateGlobalState("azureAiDeployments", apiConfiguration.azureAiDeployments)
		await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)
		await this.updateGlobalState("openRouterModelId", openRouterModelId)
		await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
@@ -2147,6 +2159,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
			openAiNativeApiKey,
			deepSeekApiKey,
			mistralApiKey,
+			azureAiDeployments,
			azureApiVersion,
			openAiStreamingEnabled,
			openRouterModelId,
@@ -2221,6 +2234,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
			this.getSecret("mistralApiKey") as Promise<string | undefined>,
+			this.getGlobalState("azureAiDeployments") as Promise<Record<string, any> | undefined>,
			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
@@ -2313,6 +2327,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
			deepSeekApiKey,
			mistralApiKey,
			azureApiVersion,
+			azureAiDeployments,
			openAiStreamingEnabled,
			openRouterModelId,
			openRouterModelInfo,
ExtensionMessage.ts
@@ -1,6 +1,4 @@
-// type that represents json data that is sent from extension to webview, called ExtensionMessage and has 'type' enum which can be 'plusButtonClicked' or 'settingsButtonClicked' or 'hello'
-import { ApiConfiguration, ApiProvider, ModelInfo } from "./api"
-
+import { ApiConfiguration, ApiProvider, ModelInfo, AzureDeploymentConfig } from "./api"
import { HistoryItem } from "./HistoryItem"
import { McpServer } from "./mcp"
import { GitCommit } from "../utils/git"
@@ -15,7 +13,6 @@ export interface LanguageModelChatSelector {
	id?: string
}

-// webview will hold state
export interface ExtensionMessage {
	type:
		| "action"
@@ -99,7 +96,7 @@ export interface ExtensionState {
	alwaysApproveResubmit?: boolean
	alwaysAllowModeSwitch?: boolean
	requestDelaySeconds: number
-	rateLimitSeconds: number // Minimum time between successive requests (0 = disabled)
+	rateLimitSeconds: number
	uriScheme?: string
	allowedCommands?: string[]
	soundEnabled?: boolean
@@ -116,10 +113,11 @@ export interface ExtensionState {
	mode: Mode
	modeApiConfigs?: Record<Mode, string>
	enhancementApiConfigId?: string
-	experiments: Record<ExperimentId, boolean> // Map of experiment IDs to their enabled state
+	experiments: Record<ExperimentId, boolean>
	autoApprovalEnabled?: boolean
	customModes: ModeConfig[]
-	toolRequirements?: Record<string, boolean> // Map of tool names to their requirements (e.g. {"apply_diff": true} if diffEnabled)
+	toolRequirements?: Record<string, boolean>
+	azureAiDeployments?: Record<string, AzureDeploymentConfig>
}

export interface ClineMessage {
@@ -190,7 +188,6 @@ export interface ClineSayTool {
	reason?: string
}

-// must keep in sync with system prompt
export const browserActions = ["launch", "click", "type", "scroll_down", "scroll_up", "close"] as const
export type BrowserAction = (typeof browserActions)[number]

WebviewMessage.ts
@@ -19,6 +19,9 @@ export interface WebviewMessage {
		| "alwaysAllowReadOnly"
		| "alwaysAllowWrite"
		| "alwaysAllowExecute"
+		| "alwaysAllowBrowser"
+		| "alwaysAllowMcp"
+		| "alwaysAllowModeSwitch"
		| "webviewDidLaunch"
		| "newTask"
		| "askResponse"
@@ -83,6 +86,7 @@ export interface WebviewMessage {
		| "deleteCustomMode"
		| "setopenAiCustomModelInfo"
		| "openCustomModesSettings"
+		| "updateAzureAiDeployment"
	text?: string
	disabled?: boolean
	askResponse?: ClineAskResponse
@@ -104,6 +108,12 @@ export interface WebviewMessage {
	slug?: string
	modeConfig?: ModeConfig
	timeout?: number
+	azureAiDeployment?: {
+		modelId: string
+		name: string
+		apiVersion: string
+		modelMeshName?: string
+	}
}

export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"
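For reference, the webview would post the new message roughly as follows, to be handled by the `updateAzureAiDeployment` case in ClineProvider shown above (a sketch; the deployment values and import path are placeholders, and `vscode` stands for the webview's acquired VS Code API handle):

```typescript
// Sketch: sending updateAzureAiDeployment from the webview. Values and the
// import path are placeholders, not from this commit.
import type { WebviewMessage } from "../shared/WebviewMessage"

declare const vscode: { postMessage(message: WebviewMessage): void }

vscode.postMessage({
	type: "updateAzureAiDeployment",
	azureAiDeployment: {
		modelId: "azure-gpt-4",
		name: "your-gpt4-deployment",
		apiVersion: "2024-02-15-preview",
	},
})
```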
shared/api.ts
@@ -15,6 +15,7 @@ export type ApiProvider =
	| "vscode-lm"
	| "mistral"
	| "unbound"
+	| "azure-ai"

export interface ApiHandlerOptions {
	apiModelId?: string
@@ -60,6 +61,15 @@ export interface ApiHandlerOptions {
	includeMaxTokens?: boolean
	unboundApiKey?: string
	unboundModelId?: string
+	azureAiEndpoint?: string
+	azureAiKey?: string
+	azureAiDeployments?: {
+		[key in AzureAiModelId]?: {
+			name: string
+			apiVersion: string
+			modelMeshName?: string
+		}
+	}
}

export type ApiConfiguration = ApiHandlerOptions & {
@@ -635,3 +645,50 @@ export const unboundModels = {
	"deepseek/deepseek-reasoner": deepSeekModels["deepseek-reasoner"],
	"mistral/codestral-latest": mistralModels["codestral-latest"],
} as const satisfies Record<string, ModelInfo>
+
+// Azure AI
+export type AzureAiModelId = "azure-gpt-35" | "azure-gpt-4" | "azure-gpt-4-turbo"
+
+export interface AzureDeploymentConfig {
+	name: string
+	apiVersion: string
+	modelMeshName?: string // For Model-Mesh deployments
+}
+
+export const azureAiModels = {
+	"azure-gpt-35": {
+		maxTokens: 4096,
+		contextWindow: 16385,
+		supportsPromptCache: true,
+		inputPrice: 0.0015,
+		outputPrice: 0.002,
+		defaultDeployment: {
+			name: "azure-gpt-35",
+			apiVersion: "2024-02-15-preview",
+		},
+	},
+	"azure-gpt-4": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsPromptCache: true,
+		inputPrice: 0.03,
+		outputPrice: 0.06,
+		defaultDeployment: {
+			name: "azure-gpt-4",
+			apiVersion: "2024-02-15-preview",
+		},
+	},
+	"azure-gpt-4-turbo": {
+		maxTokens: 4096,
+		contextWindow: 128000,
+		supportsPromptCache: true,
+		inputPrice: 0.01,
+		outputPrice: 0.03,
+		defaultDeployment: {
+			name: "azure-gpt-4-turbo",
+			apiVersion: "2024-02-15-preview",
+		},
+	},
+} as const satisfies Record<AzureAiModelId, ModelInfo & { defaultDeployment: AzureDeploymentConfig }>
+
+export const azureAiDefaultModelId: AzureAiModelId = "azure-gpt-35"