Add LM Studio provider

Saoud Rizwan
2024-11-12 22:02:42 -05:00
parent bac0b1a0cb
commit 39bc35eec1
12 changed files with 199 additions and 9 deletions

View File: src/api/index.ts

@@ -6,6 +6,7 @@ import { OpenRouterHandler } from "./providers/openrouter"
import { VertexHandler } from "./providers/vertex"
import { OpenAiHandler } from "./providers/openai"
import { OllamaHandler } from "./providers/ollama"
import { LmStudioHandler } from "./providers/lmstudio"
import { GeminiHandler } from "./providers/gemini"
import { OpenAiNativeHandler } from "./providers/openai-native"
import { ApiStream } from "./transform/stream"
@@ -30,6 +31,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new OpenAiHandler(options)
case "ollama":
return new OllamaHandler(options)
case "lmstudio":
return new LmStudioHandler(options)
case "gemini":
return new GeminiHandler(options)
case "openai-native":

View File: src/api/providers/lmstudio.ts

@@ -0,0 +1,56 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

export class LmStudioHandler implements ApiHandler {
	private options: ApiHandlerOptions
	private client: OpenAI

	constructor(options: ApiHandlerOptions) {
		this.options = options
		// LM Studio serves an OpenAI-compatible API, so we reuse the OpenAI SDK
		// pointed at the local server. The key is required by the SDK but unused.
		this.client = new OpenAI({
			baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
			apiKey: "noop",
		})
	}

	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		// Convert Anthropic-format messages to OpenAI chat format before sending.
		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
			{ role: "system", content: systemPrompt },
			...convertToOpenAiMessages(messages),
		]
		try {
			const stream = await this.client.chat.completions.create({
				model: this.getModel().id,
				messages: openAiMessages,
				temperature: 0,
				stream: true,
			})
			// Re-emit each streamed delta as a text chunk on the ApiStream.
			for await (const chunk of stream) {
				const delta = chunk.choices[0]?.delta
				if (delta?.content) {
					yield {
						type: "text",
						text: delta.content,
					}
				}
			}
		} catch (error) {
			// LM Studio doesn't return an error code/body for now, so surface a
			// generic troubleshooting hint instead of the raw error.
			throw new Error(
				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts."
			)
		}
	}

	getModel(): { id: string; info: ModelInfo } {
		return {
			// No model metadata is available locally, so fall back to sane OpenAI defaults.
			id: this.options.lmStudioModelId || "",
			info: openAiModelInfoSaneDefaults,
		}
	}
}
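
An illustrative consumer of the streaming handler (a sketch, assuming a model is already loaded in LM Studio; the model id is hypothetical):

async function demo() {
	const handler = new LmStudioHandler({ lmStudioModelId: "qwen2.5-7b-instruct" }) // hypothetical id
	const userMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello!" }]
	for await (const chunk of handler.createMessage("You are a helpful assistant.", userMessages)) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		}
	}
}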

View File: src/core/webview/ClineProvider.ts

@@ -51,6 +51,8 @@ type GlobalStateKey =
| "openAiModelId"
| "ollamaModelId"
| "ollamaBaseUrl"
| "lmStudioModelId"
| "lmStudioBaseUrl"
| "anthropicBaseUrl"
| "azureApiVersion"
| "openRouterModelId"
@@ -359,6 +361,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiModelId,
ollamaModelId,
ollamaBaseUrl,
lmStudioModelId,
lmStudioBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
@@ -382,6 +386,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("openAiModelId", openAiModelId)
await this.updateGlobalState("ollamaModelId", ollamaModelId)
await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
await this.updateGlobalState("lmStudioModelId", lmStudioModelId)
await this.updateGlobalState("lmStudioBaseUrl", lmStudioBaseUrl)
await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
await this.storeSecret("geminiApiKey", geminiApiKey)
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
@@ -442,6 +448,10 @@ export class ClineProvider implements vscode.WebviewViewProvider {
const ollamaModels = await this.getOllamaModels(message.text)
this.postMessageToWebview({ type: "ollamaModels", ollamaModels })
break
case "requestLmStudioModels":
const lmStudioModels = await this.getLmStudioModels(message.text)
this.postMessageToWebview({ type: "lmStudioModels", lmStudioModels })
break
case "refreshOpenRouterModels":
await this.refreshOpenRouterModels()
break
@@ -509,6 +519,25 @@ export class ClineProvider implements vscode.WebviewViewProvider {
}
}
// LM Studio
async getLmStudioModels(baseUrl?: string) {
try {
if (!baseUrl) {
baseUrl = "http://localhost:1234"
}
if (!URL.canParse(baseUrl)) {
return []
}
const response = await axios.get(`${baseUrl}/v1/models`)
const modelsArray = response.data?.data?.map((model: any) => model.id) || []
const models = [...new Set<string>(modelsArray)]
return models
} catch (error) {
return []
}
}
// OpenRouter
async handleOpenRouterCallback(code: string) {
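
getLmStudioModels above queries LM Studio's OpenAI-compatible /v1/models endpoint, which responds with { data: [{ id: ... }, ...] }. A standalone fetch-based sketch of the same call, under the same default URL and dedup behavior:

async function listLmStudioModels(baseUrl = "http://localhost:1234"): Promise<string[]> {
	try {
		const res = await fetch(`${baseUrl}/v1/models`)
		const body = (await res.json()) as { data?: { id: string }[] }
		// Deduplicate ids, mirroring the Set usage in getLmStudioModels.
		return [...new Set((body.data ?? []).map((m) => m.id))]
	} catch {
		return []
	}
}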
@@ -835,6 +864,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiModelId,
ollamaModelId,
ollamaBaseUrl,
lmStudioModelId,
lmStudioBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,
@@ -862,6 +893,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getGlobalState("openAiModelId") as Promise<string | undefined>,
this.getGlobalState("ollamaModelId") as Promise<string | undefined>,
this.getGlobalState("ollamaBaseUrl") as Promise<string | undefined>,
this.getGlobalState("lmStudioModelId") as Promise<string | undefined>,
this.getGlobalState("lmStudioBaseUrl") as Promise<string | undefined>,
this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
this.getSecret("geminiApiKey") as Promise<string | undefined>,
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
@@ -906,6 +939,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiModelId,
ollamaModelId,
ollamaBaseUrl,
lmStudioModelId,
lmStudioBaseUrl,
anthropicBaseUrl,
geminiApiKey,
openAiNativeApiKey,

View File: src/shared/ExtensionMessage.ts

@@ -10,6 +10,7 @@ export interface ExtensionMessage {
| "state"
| "selectedImages"
| "ollamaModels"
| "lmStudioModels"
| "theme"
| "workspaceUpdated"
| "invoke"
@@ -21,6 +22,7 @@ export interface ExtensionMessage {
state?: ExtensionState
images?: string[]
ollamaModels?: string[]
lmStudioModels?: string[]
filePaths?: string[]
partialMessage?: ClineMessage
openRouterModels?: Record<string, ModelInfo>

View File: src/shared/WebviewMessage.ts

@@ -17,6 +17,7 @@ export interface WebviewMessage {
| "exportTaskWithId"
| "resetState"
| "requestOllamaModels"
| "requestLmStudioModels"
| "openImage"
| "openFile"
| "openMention"

View File: src/shared/api.ts

@@ -5,6 +5,7 @@ export type ApiProvider =
| "vertex"
| "openai"
| "ollama"
| "lmstudio"
| "gemini"
| "openai-native"
@@ -27,6 +28,8 @@ export interface ApiHandlerOptions {
openAiModelId?: string
ollamaModelId?: string
ollamaBaseUrl?: string
lmStudioModelId?: string
lmStudioBaseUrl?: string
geminiApiKey?: string
openAiNativeApiKey?: string
azureApiVersion?: string