Add ollama provider option

This commit is contained in:
Saoud Rizwan
2024-09-03 23:03:30 -04:00
parent 4b99561294
commit 286e569e09
8 changed files with 140 additions and 8 deletions

View File

@@ -5,6 +5,7 @@ import { AwsBedrockHandler } from "./bedrock"
import { OpenRouterHandler } from "./openrouter"
import { VertexHandler } from "./vertex"
import { OpenAiHandler } from "./openai"
import { OllamaHandler } from "./ollama"
export interface ApiHandlerMessageResponse {
message: Anthropic.Messages.Message
@@ -43,6 +44,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new VertexHandler(options)
case "openai":
return new OpenAiHandler(options)
case "ollama":
return new OllamaHandler(options)
default:
return new AnthropicHandler(options)
}

74
src/api/ollama.ts Normal file
View File

@@ -0,0 +1,74 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse, withoutImageData } from "."
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
export class OllamaHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
baseURL: "http://localhost:11434/v1",
apiKey: "ollama",
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.options.ollamaModelId ?? "",
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
const anthropicMessage = convertToAnthropicMessage(completion)
return { message: anthropicMessage }
}
createUserReadableRequest(
userContent: Array<
| Anthropic.TextBlockParam
| Anthropic.ImageBlockParam
| Anthropic.ToolUseBlockParam
| Anthropic.ToolResultBlockParam
>
): any {
return {
model: this.options.ollamaModelId ?? "",
system: "(see SYSTEM_PROMPT in src/ClaudeDev.ts)",
messages: [{ conversation_history: "..." }, { role: "user", content: withoutImageData(userContent) }],
tools: "(see tools in src/ClaudeDev.ts)",
tool_choice: "auto",
}
}
getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.ollamaModelId ?? "",
info: openAiModelInfoSaneDefaults,
}
}
}

View File

@@ -29,6 +29,7 @@ type GlobalStateKey =
| "taskHistory"
| "openAiBaseUrl"
| "openAiModelId"
| "ollamaModelId"
export class ClaudeDevProvider implements vscode.WebviewViewProvider {
public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
@@ -319,6 +320,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
openAiBaseUrl,
openAiApiKey,
openAiModelId,
ollamaModelId,
} = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId)
@@ -333,6 +335,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
await this.storeSecret("openAiApiKey", openAiApiKey)
await this.updateGlobalState("openAiModelId", openAiModelId)
await this.updateGlobalState("ollamaModelId", ollamaModelId)
this.claudeDev?.updateApi(message.apiConfiguration)
}
await this.postStateToWebview()
@@ -623,6 +626,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
openAiBaseUrl,
openAiApiKey,
openAiModelId,
ollamaModelId,
lastShownAnnouncementId,
customInstructions,
alwaysAllowReadOnly,
@@ -641,6 +645,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
this.getSecret("openAiApiKey") as Promise<string | undefined>,
this.getGlobalState("openAiModelId") as Promise<string | undefined>,
this.getGlobalState("ollamaModelId") as Promise<string | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
this.getGlobalState("customInstructions") as Promise<string | undefined>,
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -676,6 +681,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
openAiBaseUrl,
openAiApiKey,
openAiModelId,
ollamaModelId,
},
lastShownAnnouncementId,
customInstructions,

View File

@@ -1,4 +1,4 @@
export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai"
export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama"
export interface ApiHandlerOptions {
apiModelId?: string
@@ -13,6 +13,7 @@ export interface ApiHandlerOptions {
openAiBaseUrl?: string
openAiApiKey?: string
openAiModelId?: string
ollamaModelId?: string
}
export type ApiConfiguration = ApiHandlerOptions & {