From 077fa843746c33e61ee702e479862352227539ab Mon Sep 17 00:00:00 2001 From: Saoud Rizwan <7799382+saoudrizwan@users.noreply.github.com> Date: Thu, 16 Jan 2025 19:40:27 -0800 Subject: [PATCH] Add Mistral API provider --- package-lock.json | 9 ++ package.json | 1 + src/api/index.ts | 3 + src/api/providers/mistral.ts | 74 +++++++++++++++ src/api/transform/mistral-format.ts | 92 +++++++++++++++++++ src/core/webview/ClineProvider.ts | 7 ++ .../__tests__/checkExistApiConfig.test.ts | 1 + src/shared/api.ts | 17 ++++ src/shared/checkExistApiConfig.ts | 1 + .../src/components/settings/ApiOptions.tsx | 37 ++++++++ webview-ui/src/utils/validate.ts | 5 + 11 files changed, 247 insertions(+) create mode 100644 src/api/providers/mistral.ts create mode 100644 src/api/transform/mistral-format.ts diff --git a/package-lock.json b/package-lock.json index 133fb44..f235d2a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,6 +13,7 @@ "@anthropic-ai/vertex-sdk": "^0.4.1", "@aws-sdk/client-bedrock-runtime": "^3.706.0", "@google/generative-ai": "^0.18.0", + "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.0.1", "@types/clone-deep": "^4.0.4", "@types/pdf-parse": "^1.1.4", @@ -4254,6 +4255,14 @@ "node": ">=8" } }, + "node_modules/@mistralai/mistralai": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/@mistralai/mistralai/-/mistralai-1.3.6.tgz", + "integrity": "sha512-2y7U5riZq+cIjKpxGO9y417XuZv9CpBXEAvbjRMzWPGhXY7U1ZXj4VO4H9riS2kFZqTR2yLEKSE6/pGWVVIqgQ==", + "peerDependencies": { + "zod": ">= 3" + } + }, "node_modules/@mixmark-io/domino": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/@mixmark-io/domino/-/domino-2.2.0.tgz", diff --git a/package.json b/package.json index 5632d3c..b02f834 100644 --- a/package.json +++ b/package.json @@ -226,6 +226,7 @@ "@anthropic-ai/vertex-sdk": "^0.4.1", "@aws-sdk/client-bedrock-runtime": "^3.706.0", "@google/generative-ai": "^0.18.0", + "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.0.1", "@types/clone-deep": "^4.0.4", "@types/pdf-parse": "^1.1.4", diff --git a/src/api/index.ts b/src/api/index.ts index 647c538..641c50d 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -11,6 +11,7 @@ import { LmStudioHandler } from "./providers/lmstudio" import { GeminiHandler } from "./providers/gemini" import { OpenAiNativeHandler } from "./providers/openai-native" import { DeepSeekHandler } from "./providers/deepseek" +import { MistralHandler } from "./providers/mistral" import { VsCodeLmHandler } from "./providers/vscode-lm" import { ApiStream } from "./transform/stream" @@ -50,6 +51,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler { return new DeepSeekHandler(options) case "vscode-lm": return new VsCodeLmHandler(options) + case "mistral": + return new MistralHandler(options) default: return new AnthropicHandler(options) } diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts new file mode 100644 index 0000000..c4377f0 --- /dev/null +++ b/src/api/providers/mistral.ts @@ -0,0 +1,74 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import { Mistral } from "@mistralai/mistralai" +import { ApiHandler } from "../" +import { + ApiHandlerOptions, + mistralDefaultModelId, + MistralModelId, + mistralModels, + ModelInfo, + openAiNativeDefaultModelId, + OpenAiNativeModelId, + openAiNativeModels, +} from "../../shared/api" +import { convertToMistralMessages } from "../transform/mistral-format" +import { ApiStream } from "../transform/stream" + +export class MistralHandler implements ApiHandler { + private options: ApiHandlerOptions + private client: Mistral + + constructor(options: ApiHandlerOptions) { + this.options = options + this.client = new Mistral({ + serverURL: "https://codestral.mistral.ai", + apiKey: this.options.mistralApiKey, + }) + } + + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const stream = await this.client.chat.stream({ + model: this.getModel().id, + // max_completion_tokens: this.getModel().info.maxTokens, + temperature: 0, + messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)], + stream: true, + }) + + for await (const chunk of stream) { + const delta = chunk.data.choices[0]?.delta + if (delta?.content) { + let content: string = "" + if (typeof delta.content === "string") { + content = delta.content + } else if (Array.isArray(delta.content)) { + content = delta.content.map((c) => (c.type === "text" ? c.text : "")).join("") + } + yield { + type: "text", + text: content, + } + } + + if (chunk.data.usage) { + yield { + type: "usage", + inputTokens: chunk.data.usage.promptTokens || 0, + outputTokens: chunk.data.usage.completionTokens || 0, + } + } + } + } + + getModel(): { id: MistralModelId; info: ModelInfo } { + const modelId = this.options.apiModelId + if (modelId && modelId in mistralModels) { + const id = modelId as MistralModelId + return { id, info: mistralModels[id] } + } + return { + id: mistralDefaultModelId, + info: mistralModels[mistralDefaultModelId], + } + } +} diff --git a/src/api/transform/mistral-format.ts b/src/api/transform/mistral-format.ts new file mode 100644 index 0000000..16c6aaf --- /dev/null +++ b/src/api/transform/mistral-format.ts @@ -0,0 +1,92 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import { Mistral } from "@mistralai/mistralai" +import { AssistantMessage } from "@mistralai/mistralai/models/components/assistantmessage" +import { SystemMessage } from "@mistralai/mistralai/models/components/systemmessage" +import { ToolMessage } from "@mistralai/mistralai/models/components/toolmessage" +import { UserMessage } from "@mistralai/mistralai/models/components/usermessage" + +export type MistralMessage = + | (SystemMessage & { role: "system" }) + | (UserMessage & { role: "user" }) + | (AssistantMessage & { role: "assistant" }) + | (ToolMessage & { role: "tool" }) + +export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): MistralMessage[] { + const mistralMessages: MistralMessage[] = [] + for (const anthropicMessage of anthropicMessages) { + if (typeof anthropicMessage.content === "string") { + mistralMessages.push({ + role: anthropicMessage.role, + content: anthropicMessage.content, + }) + } else { + if (anthropicMessage.role === "user") { + const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ + nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] + toolMessages: Anthropic.ToolResultBlockParam[] + }>( + (acc, part) => { + if (part.type === "tool_result") { + acc.toolMessages.push(part) + } else if (part.type === "text" || part.type === "image") { + acc.nonToolMessages.push(part) + } // user cannot send tool_use messages + return acc + }, + { nonToolMessages: [], toolMessages: [] }, + ) + + if (nonToolMessages.length > 0) { + mistralMessages.push({ + role: "user", + content: nonToolMessages.map((part) => { + if (part.type === "image") { + return { + type: "image_url", + imageUrl: { + url: `data:${part.source.media_type};base64,${part.source.data}`, + }, + } + } + return { type: "text", text: part.text } + }), + }) + } + } else if (anthropicMessage.role === "assistant") { + const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ + nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] + toolMessages: Anthropic.ToolUseBlockParam[] + }>( + (acc, part) => { + if (part.type === "tool_use") { + acc.toolMessages.push(part) + } else if (part.type === "text" || part.type === "image") { + acc.nonToolMessages.push(part) + } // assistant cannot send tool_result messages + return acc + }, + { nonToolMessages: [], toolMessages: [] }, + ) + + let content: string | undefined + if (nonToolMessages.length > 0) { + content = nonToolMessages + .map((part) => { + if (part.type === "image") { + return "" // impossible as the assistant cannot send images + } + return part.text + }) + .join("\n") + } + + mistralMessages.push({ + role: "assistant", + content, + }) + } + } + } + + return mistralMessages +} diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 9ff3613..b4e2f76 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -49,6 +49,7 @@ type SecretKey = | "geminiApiKey" | "openAiNativeApiKey" | "deepSeekApiKey" + | "mistralApiKey" type GlobalStateKey = | "apiProvider" | "apiModelId" @@ -1120,6 +1121,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openRouterModelInfo, openRouterUseMiddleOutTransform, vsCodeLmModelSelector, + mistralApiKey, } = apiConfiguration await this.updateGlobalState("apiProvider", apiProvider) await this.updateGlobalState("apiModelId", apiModelId) @@ -1152,6 +1154,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo) await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform) await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector) + await this.storeSecret("mistralApiKey", mistralApiKey) if (this.cline) { this.cline.api = buildApiHandler(apiConfiguration) } @@ -1766,6 +1769,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { geminiApiKey, openAiNativeApiKey, deepSeekApiKey, + mistralApiKey, azureApiVersion, openAiStreamingEnabled, openRouterModelId, @@ -1826,6 +1830,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getSecret("geminiApiKey") as Promise, this.getSecret("openAiNativeApiKey") as Promise, this.getSecret("deepSeekApiKey") as Promise, + this.getSecret("mistralApiKey") as Promise, this.getGlobalState("azureApiVersion") as Promise, this.getGlobalState("openAiStreamingEnabled") as Promise, this.getGlobalState("openRouterModelId") as Promise, @@ -1903,6 +1908,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { geminiApiKey, openAiNativeApiKey, deepSeekApiKey, + mistralApiKey, azureApiVersion, openAiStreamingEnabled, openRouterModelId, @@ -2041,6 +2047,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { "geminiApiKey", "openAiNativeApiKey", "deepSeekApiKey", + "mistralApiKey", ] for (const key of secretKeys) { await this.storeSecret(key, undefined) diff --git a/src/shared/__tests__/checkExistApiConfig.test.ts b/src/shared/__tests__/checkExistApiConfig.test.ts index 13b64f5..c0fcb64 100644 --- a/src/shared/__tests__/checkExistApiConfig.test.ts +++ b/src/shared/__tests__/checkExistApiConfig.test.ts @@ -49,6 +49,7 @@ describe('checkExistKey', () => { geminiApiKey: undefined, openAiNativeApiKey: undefined, deepSeekApiKey: undefined, + mistralApiKey: undefined, vsCodeLmModelSelector: undefined }; expect(checkExistKey(config)).toBe(false); diff --git a/src/shared/api.ts b/src/shared/api.ts index 908ffec..9721f65 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -13,6 +13,7 @@ export type ApiProvider = | "openai-native" | "deepseek" | "vscode-lm" + | "mistral" export interface ApiHandlerOptions { apiModelId?: string @@ -43,6 +44,7 @@ export interface ApiHandlerOptions { lmStudioBaseUrl?: string geminiApiKey?: string openAiNativeApiKey?: string + mistralApiKey?: string azureApiVersion?: string openRouterUseMiddleOutTransform?: boolean openAiStreamingEnabled?: boolean @@ -549,3 +551,18 @@ export const deepSeekModels = { // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs export const azureOpenAiDefaultApiVersion = "2024-08-01-preview" + +// Mistral +// https://docs.mistral.ai/getting-started/models/models_overview/ +export type MistralModelId = keyof typeof mistralModels +export const mistralDefaultModelId: MistralModelId = "codestral-latest" +export const mistralModels = { + "codestral-latest": { + maxTokens: 32_768, + contextWindow: 256_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.3, + outputPrice: 0.9, + }, +} as const satisfies Record diff --git a/src/shared/checkExistApiConfig.ts b/src/shared/checkExistApiConfig.ts index c876e81..6dec4ff 100644 --- a/src/shared/checkExistApiConfig.ts +++ b/src/shared/checkExistApiConfig.ts @@ -14,6 +14,7 @@ export function checkExistKey(config: ApiConfiguration | undefined) { config.geminiApiKey, config.openAiNativeApiKey, config.deepSeekApiKey, + config.mistralApiKey, config.vsCodeLmModelSelector, ].some((key) => key !== undefined) : false; diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index eff166b..13edd85 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -22,6 +22,8 @@ import { geminiModels, glamaDefaultModelId, glamaDefaultModelInfo, + mistralDefaultModelId, + mistralModels, openAiModelInfoSaneDefaults, openAiNativeDefaultModelId, openAiNativeModels, @@ -145,6 +147,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = { value: "bedrock", label: "AWS Bedrock" }, { value: "glama", label: "Glama" }, { value: "vscode-lm", label: "VS Code LM API" }, + { value: "mistral", label: "Mistral" }, { value: "lmstudio", label: "LM Studio" }, { value: "ollama", label: "Ollama" } ]} @@ -258,6 +261,37 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) = )} + {selectedProvider === "mistral" && ( +
+ + Mistral API Key + +

+ This key is stored locally and only used to make API requests from this extension. + {!apiConfiguration?.mistralApiKey && ( + + You can get a Mistral API key by signing up here. + + )} +

+
+ )} + {selectedProvider === "openrouter" && (