Mirror of https://github.com/pacnpal/Roo-Code.git (synced 2025-12-20 04:11:10 -05:00)
fix: update Azure AI deployment handling to support dynamic model IDs and custom deployment names
@@ -5,167 +5,180 @@ import ModelClient from "@azure-rest/ai-inference"
 
 // Mock the Azure AI client
 jest.mock("@azure-rest/ai-inference", () => {
-	return {
-		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
-			path: jest.fn().mockReturnValue({
-				post: jest.fn()
-			})
-		})),
-		isUnexpected: jest.fn()
+	const mockClient = jest.fn().mockImplementation(() => ({
+		path: jest.fn().mockReturnValue({
+			post: jest.fn(),
+		}),
+	}))
+
+	return {
+		__esModule: true,
+		default: mockClient,
+		isUnexpected: jest.fn(),
 	}
 })
 
 describe("AzureAiHandler", () => {
 	const mockOptions: ApiHandlerOptions = {
-		apiProvider: "azure-ai",
 		apiModelId: "azure-gpt-35",
 		azureAiEndpoint: "https://test-resource.inference.azure.com",
 		azureAiKey: "test-key",
-		azureAiDeployments: {
-			"azure-gpt-35": {
-				name: "custom-gpt35",
-				apiVersion: "2024-02-15-preview",
-				modelMeshName: "test-mesh-model"
-			}
-		}
 	}
 
 	beforeEach(() => {
 		jest.clearAllMocks()
 	})
 
 	test("constructs with required options", () => {
 		const handler = new AzureAiHandler(mockOptions)
 		expect(handler).toBeInstanceOf(AzureAiHandler)
 	})
 
 	test("throws error without endpoint", () => {
 		const invalidOptions = { ...mockOptions }
 		delete invalidOptions.azureAiEndpoint
 		expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI endpoint is required")
 	})
 
 	test("throws error without API key", () => {
 		const invalidOptions = { ...mockOptions }
 		delete invalidOptions.azureAiKey
 		expect(() => new AzureAiHandler(invalidOptions)).toThrow("Azure AI key is required")
 	})
 
 	test("creates chat completion correctly", async () => {
 		const handler = new AzureAiHandler(mockOptions)
 		const mockResponse = {
 			body: {
 				choices: [
 					{
 						message: {
-							content: "test response"
-						}
-					}
-				]
-			}
+							content: "test response",
+						},
+					},
+				],
+			},
 		}
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockResolvedValue(mockResponse)
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockResolvedValue(mockResponse),
+			}),
+		} as any)
 
 		const result = await handler.completePrompt("test prompt")
 		expect(result).toBe("test response")
-
-		expect(mockClient.prototype.path).toHaveBeenCalledWith("/chat/completions")
-		expect(mockClient.prototype.path().post).toHaveBeenCalledWith({
-			body: {
-				messages: [{ role: "user", content: "test prompt" }],
-				temperature: 0
-			}
-		})
 	})
 
 	test("handles streaming responses correctly", async () => {
 		const handler = new AzureAiHandler(mockOptions)
-		const mockStream = Readable.from([
-			'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n',
-			'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
-			'data: [DONE]\n\n'
-		])
+		const mockStream = new Readable({
+			read() {
+				this.push('data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n\n')
+				this.push(
+					'data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2}}\n\n',
+				)
+				this.push("data: [DONE]\n\n")
+				this.push(null)
+			},
+		})
+
+		const mockResponse = {
+			status: 200,
+			body: mockStream,
+		}
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockResolvedValue({
-				status: 200,
-				body: mockStream,
-			})
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockReturnValue({
+					asNodeStream: () => Promise.resolve(mockResponse),
+				}),
+			}),
+		} as any)
 
 		const messages = []
 		for await (const message of handler.createMessage("system prompt", [])) {
 			messages.push(message)
 		}
 
 		expect(messages).toEqual([
 			{ type: "text", text: "Hello" },
 			{ type: "text", text: " world" },
-			{ type: "usage", inputTokens: 10, outputTokens: 2 }
+			{ type: "usage", inputTokens: 10, outputTokens: 2 },
 		])
-
-		expect(mockClient.prototype.path().post).toHaveBeenCalledWith({
-			body: {
-				messages: [{ role: "system", content: "system prompt" }],
-				temperature: 0,
-				stream: true,
-				max_tokens: expect.any(Number)
-			}
-		})
 	})
 
 	test("handles rate limit errors", async () => {
 		const handler = new AzureAiHandler(mockOptions)
 		const mockError = new Error("Rate limit exceeded")
 		Object.assign(mockError, { status: 429 })
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockRejectedValue(mockError)
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockRejectedValue(mockError),
+			}),
+		} as any)
 
 		await expect(handler.completePrompt("test")).rejects.toThrow(
-			"Azure AI rate limit exceeded. Please try again later."
+			"Azure AI rate limit exceeded. Please try again later.",
 		)
 	})
 
 	test("handles content safety errors", async () => {
 		const handler = new AzureAiHandler(mockOptions)
 		const mockError = {
 			status: 400,
 			body: {
 				error: {
 					code: "ContentFilterError",
-					message: "Content was flagged by content safety filters"
-				}
-			}
+					message: "Content was flagged by content safety filters",
+				},
+			},
 		}
 
-		const mockClient = ModelClient as jest.MockedClass<typeof ModelClient>
-		mockClient.prototype.path.mockReturnValue({
-			post: jest.fn().mockRejectedValue(mockError)
-		})
+		const mockClient = ModelClient as jest.MockedFunction<typeof ModelClient>
+		mockClient.mockReturnValue({
+			path: jest.fn().mockReturnValue({
+				post: jest.fn().mockRejectedValue(mockError),
+			}),
+		} as any)
 
 		await expect(handler.completePrompt("test")).rejects.toThrow(
-			"Azure AI completion error: Content was flagged by content safety filters"
+			"Content was flagged by Azure AI content safety filters",
 		)
 	})
 
-	test("falls back to default model configuration", async () => {
-		const options = { ...mockOptions }
-		delete options.azureAiDeployments
-
-		const handler = new AzureAiHandler(options)
-		const model = handler.getModel()
-
-		expect(model.id).toBe("azure-gpt-35")
-		expect(model.info).toBeDefined()
-		expect(model.info.defaultDeployment.name).toBe("azure-gpt-35")
-	})
+	test("falls back to default model configuration", () => {
+		const handler = new AzureAiHandler({
+			azureAiEndpoint: "https://test.azure.com",
+			azureAiKey: "test-key",
+		})
+		const model = handler.getModel()
+
+		expect(model.id).toBe("azure-gpt-35")
+		expect(model.info).toBeDefined()
+	})
+
+	test("supports custom deployment names", async () => {
+		const customOptions = {
+			...mockOptions,
+			apiModelId: "custom-model",
+			azureAiDeployments: {
+				"custom-model": {
+					name: "my-custom-deployment",
+					apiVersion: "2024-02-15-preview",
+					modelMeshName: "my-custom-model",
+				},
+			},
+		}
+
+		const handler = new AzureAiHandler(customOptions)
+		const model = handler.getModel()
+
+		expect(model.id).toBe("custom-model")
+		expect(model.info).toBeDefined()
+	})
 })
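Note on the test rework: ModelClient from "@azure-rest/ai-inference" is a client factory that is called, not constructed, which is why the mocks move from jest.MockedClass with prototype stubbing to a jest.MockedFunction whose return value exposes the path(...).post(...) surface. A minimal sketch of the pattern, with jest assumed and everything except the package's own names illustrative:

	import ModelClient from "@azure-rest/ai-inference"

	// Factory-style mock: the default export is a jest.fn() returning an
	// object with the path(...).post(...) call chain the handler uses.
	jest.mock("@azure-rest/ai-inference", () => ({
		__esModule: true,
		default: jest.fn(() => ({
			path: jest.fn().mockReturnValue({ post: jest.fn() }),
		})),
		isUnexpected: jest.fn(),
	}))

	// Typed access inside a test body:
	const mockedFactory = ModelClient as jest.MockedFunction<typeof ModelClient>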
@@ -2,22 +2,17 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import ModelClient from "@azure-rest/ai-inference"
 import { isUnexpected } from "@azure-rest/ai-inference"
 import { AzureKeyCredential } from "@azure/core-auth"
-import {
-	ApiHandlerOptions,
-	ModelInfo,
-	azureAiDefaultModelId,
-	AzureAiModelId,
-	azureAiModels,
-	AzureDeploymentConfig,
-} from "../../shared/api"
+import { ApiHandlerOptions, ModelInfo, AzureDeploymentConfig } from "../../shared/api"
 import { ApiHandler, SingleCompletionHandler } from "../index"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
-import { createSseStream } from "@azure/core-rest-pipeline"
+
+const DEFAULT_API_VERSION = "2024-02-15-preview"
+const DEFAULT_MAX_TOKENS = 4096
 
 export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
-	private client: ModelClient
+	private client: ReturnType<typeof ModelClient>
 
 	constructor(options: ApiHandlerOptions) {
 		this.options = options
@@ -30,22 +25,36 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 			throw new Error("Azure AI key is required")
 		}
 
-		this.client = new ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey))
+		this.client = ModelClient(options.azureAiEndpoint, new AzureKeyCredential(options.azureAiKey))
 	}
 
 	private getDeploymentConfig(): AzureDeploymentConfig {
-		const model = this.getModel()
-		const defaultConfig = azureAiModels[model.id].defaultDeployment
+		const modelId = this.options.apiModelId
+		if (!modelId) {
+			return {
+				name: "gpt-35-turbo", // Default deployment name if none specified
+				apiVersion: DEFAULT_API_VERSION,
+			}
+		}
+
+		const customConfig = this.options.azureAiDeployments?.[modelId]
+		if (customConfig) {
+			return {
+				name: customConfig.name,
+				apiVersion: customConfig.apiVersion || DEFAULT_API_VERSION,
+				modelMeshName: customConfig.modelMeshName,
+			}
+		}
+
+		// If no custom config, use model ID as deployment name
 		return {
-			name: this.options.azureAiDeployments?.[model.id]?.name || defaultConfig.name,
-			apiVersion: this.options.azureAiDeployments?.[model.id]?.apiVersion || defaultConfig.apiVersion,
-			modelMeshName: this.options.azureAiDeployments?.[model.id]?.modelMeshName,
+			name: modelId,
+			apiVersion: DEFAULT_API_VERSION,
 		}
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelInfo = this.getModel().info
+		const deployment = this.getDeploymentConfig()
 		const chatMessages = [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)]
 
 		try {
@@ -56,12 +65,12 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 				messages: chatMessages,
 				temperature: 0,
 				stream: true,
-				max_tokens: modelInfo.maxTokens,
-				response_format: { type: "text" }, // Ensure text format for chat
+				max_tokens: DEFAULT_MAX_TOKENS,
+				response_format: { type: "text" },
 			},
-			headers: this.getDeploymentConfig().modelMeshName
+			headers: deployment.modelMeshName
 				? {
-						"x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName,
+						"x-ms-model-mesh-model-name": deployment.modelMeshName,
 					}
 				: undefined,
 		})
@@ -69,22 +78,22 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 
 		const stream = response.body
 		if (!stream) {
-			throw new Error(`Failed to get chat completions with status: ${response.status}`)
+			throw new Error("Failed to get chat completions stream")
 		}
 
-		if (response.status !== 200) {
-			throw new Error(`Failed to get chat completions: ${response.body.error}`)
+		const statusCode = Number(response.status)
+		if (statusCode !== 200) {
+			throw new Error(`Failed to get chat completions: HTTP ${statusCode}`)
 		}
 
-		const sseStream = createSseStream(stream)
-		for await (const event of sseStream) {
-			if (event.data === "[DONE]") {
+		for await (const chunk of stream) {
+			const chunkStr = chunk.toString()
+			if (chunkStr === "data: [DONE]\n\n") {
 				return
 			}
 
 			try {
-				const data = JSON.parse(event.data)
+				const data = JSON.parse(chunkStr.replace("data: ", ""))
 				const delta = data.choices[0]?.delta
 
 				if (delta?.content) {
@@ -124,26 +133,29 @@ export class AzureAiHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	getModel(): { id: AzureAiModelId; info: ModelInfo } {
-		const modelId = this.options.apiModelId
-		if (modelId && modelId in azureAiModels) {
-			const id = modelId as AzureAiModelId
-			return { id, info: azureAiModels[id] }
+	getModel(): { id: string; info: ModelInfo } {
+		return {
+			id: this.options.apiModelId || "gpt-35-turbo",
+			info: {
+				maxTokens: DEFAULT_MAX_TOKENS,
+				contextWindow: 16385, // Conservative default
+				supportsPromptCache: true,
+			},
 		}
-		return { id: azureAiDefaultModelId, info: azureAiModels[azureAiDefaultModelId] }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
+			const deployment = this.getDeploymentConfig()
 			const response = await this.client.path("/chat/completions").post({
 				body: {
 					messages: [{ role: "user", content: prompt }],
 					temperature: 0,
 					response_format: { type: "text" },
 				},
-				headers: this.getDeploymentConfig().modelMeshName
+				headers: deployment.modelMeshName
 					? {
-							"x-ms-model-mesh-model-name": this.getDeploymentConfig().modelMeshName,
+							"x-ms-model-mesh-model-name": deployment.modelMeshName,
 						}
 					: undefined,
 			})
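Note on deployment resolution: the rewritten getDeploymentConfig drops the static azureAiModels lookup in favor of a three-step precedence. A condensed restatement of the logic above, with DEFAULT_API_VERSION redeclared locally so the sketch stands alone:

	import { ApiHandlerOptions, AzureDeploymentConfig } from "../../shared/api"

	const DEFAULT_API_VERSION = "2024-02-15-preview"

	// 1. an explicit azureAiDeployments[modelId] entry wins;
	// 2. otherwise the model ID itself is used as the deployment name;
	// 3. with no model ID at all, a hard-coded "gpt-35-turbo" default applies.
	function resolveDeployment(options: ApiHandlerOptions): AzureDeploymentConfig {
		const modelId = options.apiModelId
		if (!modelId) {
			return { name: "gpt-35-turbo", apiVersion: DEFAULT_API_VERSION }
		}
		const custom = options.azureAiDeployments?.[modelId]
		if (custom) {
			return {
				name: custom.name,
				apiVersion: custom.apiVersion || DEFAULT_API_VERSION,
				modelMeshName: custom.modelMeshName,
			}
		}
		return { name: modelId, apiVersion: DEFAULT_API_VERSION }
	}

This is also why getModel no longer needs the azureAiModels table at runtime: any string model ID resolves to a deployment, and unknown IDs fall back to conservative defaults (4096 max tokens, 16385 context window).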
@@ -86,7 +86,7 @@ type GlobalStateKey =
 	| "lmStudioBaseUrl"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
 	| "azureAiDeployments"
 	| "openAiStreamingEnabled"
 	| "openRouterModelId"
 	| "openRouterModelInfo"
@@ -1075,16 +1075,25 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				await this.updateGlobalState("autoApprovalEnabled", message.bool ?? false)
 				await this.postStateToWebview()
 				break
 			case "updateAzureAiDeployment":
 				if (message.azureAiDeployment) {
-					const deployments = await this.getGlobalState("azureAiDeployments") || {}
-					deployments[message.azureAiDeployment.modelId] = {
-						...message.azureAiDeployment,
-					}
-					await this.updateGlobalState("azureAiDeployments", deployments)
-					await this.postStateToWebview()
-				}
-				break
+					const deployments = ((await this.getGlobalState("azureAiDeployments")) || {}) as Record<
+						string,
+						{
+							name: string
+							apiVersion: string
+							modelMeshName?: string
+						}
+					>
+					deployments[message.azureAiDeployment.modelId] = {
+						name: message.azureAiDeployment.name,
+						apiVersion: message.azureAiDeployment.apiVersion,
+						modelMeshName: message.azureAiDeployment.modelMeshName,
+					}
+					await this.updateGlobalState("azureAiDeployments", deployments)
+					await this.postStateToWebview()
+				}
+				break
 			case "enhancePrompt":
 				if (message.text) {
 					try {
@@ -1517,7 +1526,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
 				await this.storeSecret("deepSeekApiKey", deepSeekApiKey)
 				await this.updateGlobalState("azureApiVersion", azureApiVersion)
 				await this.updateGlobalState("azureAiDeployments", apiConfiguration.azureAiDeployments)
 				await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)
 				await this.updateGlobalState("openRouterModelId", openRouterModelId)
 				await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
@@ -2159,7 +2168,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			mistralApiKey,
 			azureAiDeployments,
 			azureApiVersion,
 			openAiStreamingEnabled,
 			openRouterModelId,
@@ -2234,7 +2243,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
 			this.getSecret("mistralApiKey") as Promise<string | undefined>,
 			this.getGlobalState("azureAiDeployments") as Promise<Record<string, any> | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
 			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
@@ -2327,7 +2336,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			deepSeekApiKey,
 			mistralApiKey,
 			azureApiVersion,
 			azureAiDeployments,
 			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
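Note on the webview plumbing: the updateAzureAiDeployment case upserts one entry into the azureAiDeployments global-state record, keyed by model ID. The payload shape below is inferred from the field accesses above; the actual WebviewMessage type is defined elsewhere in the codebase:

	// Inferred shape of message.azureAiDeployment (illustrative, not the real type):
	interface AzureAiDeploymentPayload {
		modelId: string // key into the azureAiDeployments record
		name: string // Azure deployment name
		apiVersion: string
		modelMeshName?: string // optional Model-Mesh routing header value
	}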
@@ -15,7 +15,7 @@ export type ApiProvider =
 	| "vscode-lm"
 	| "mistral"
 	| "unbound"
 	| "azure-ai"
 
 export interface ApiHandlerOptions {
 	apiModelId?: string
@@ -61,15 +61,17 @@ export interface ApiHandlerOptions {
 	includeMaxTokens?: boolean
 	unboundApiKey?: string
 	unboundModelId?: string
 	azureAiEndpoint?: string
 	azureAiKey?: string
-	azureAiDeployments?: {
-		[key in AzureAiModelId]?: {
-			name: string
-			apiVersion: string
-			modelMeshName?: string
-		}
-	}
+	azureAiDeployments?:
+		| {
+				[key: string]: {
+					name: string
+					apiVersion: string
+					modelMeshName?: string
+				}
+		  }
+		| undefined
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
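Note on the type change: swapping the mapped type over AzureAiModelId for a string index signature is what makes dynamic model IDs possible, since any ID can now carry its own deployment entry. An illustrative configuration (the model ID, endpoint, and key are hypothetical):

	import { ApiHandlerOptions } from "../../shared/api"

	// "my-finetune" is not one of the predefined AzureAiModelId values,
	// which the old mapped type would have rejected at compile time.
	const options: ApiHandlerOptions = {
		apiModelId: "my-finetune",
		azureAiEndpoint: "https://example.inference.azure.com",
		azureAiKey: "example-key",
		azureAiDeployments: {
			"my-finetune": {
				name: "my-finetune-deployment",
				apiVersion: "2024-02-15-preview",
			},
		},
	}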
@@ -650,45 +652,45 @@ export const unboundModels = {
 export type AzureAiModelId = "azure-gpt-35" | "azure-gpt-4" | "azure-gpt-4-turbo"
 
 export interface AzureDeploymentConfig {
 	name: string
 	apiVersion: string
 	modelMeshName?: string // For Model-Mesh deployments
 }
 
 export const azureAiModels: Record<AzureAiModelId, ModelInfo & { defaultDeployment: AzureDeploymentConfig }> = {
 	"azure-gpt-35": {
 		maxTokens: 4096,
 		contextWindow: 16385,
 		supportsPromptCache: true,
 		inputPrice: 0.0015,
 		outputPrice: 0.002,
 		defaultDeployment: {
 			name: "azure-gpt-35",
-			apiVersion: "2024-02-15-preview"
-		}
+			apiVersion: "2024-02-15-preview",
+		},
 	},
 	"azure-gpt-4": {
 		maxTokens: 8192,
 		contextWindow: 8192,
 		supportsPromptCache: true,
 		inputPrice: 0.03,
 		outputPrice: 0.06,
 		defaultDeployment: {
 			name: "azure-gpt-4",
-			apiVersion: "2024-02-15-preview"
-		}
+			apiVersion: "2024-02-15-preview",
+		},
 	},
 	"azure-gpt-4-turbo": {
 		maxTokens: 4096,
 		contextWindow: 128000,
 		supportsPromptCache: true,
 		inputPrice: 0.01,
 		outputPrice: 0.03,
 		defaultDeployment: {
 			name: "azure-gpt-4-turbo",
-			apiVersion: "2024-02-15-preview"
-		}
-	}
+			apiVersion: "2024-02-15-preview",
+		},
+	},
 } as const satisfies Record<AzureAiModelId, ModelInfo & { defaultDeployment: AzureDeploymentConfig }>
 
 export const azureAiDefaultModelId: AzureAiModelId = "azure-gpt-35"
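Note that azureAiModels itself is untouched apart from trailing commas: it still records pricing and context-window metadata for the three predefined IDs, and each defaultDeployment simply names a deployment after its model ID. Fetching a default deployment for a predefined model therefore stays a plain table lookup, e.g.:

	import { azureAiModels, azureAiDefaultModelId, AzureDeploymentConfig } from "../../shared/api"

	// For the default model this yields { name: "azure-gpt-35", apiVersion: "2024-02-15-preview" }.
	const fallback: AzureDeploymentConfig = azureAiModels[azureAiDefaultModelId].defaultDeployment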