Add openai compatible provider

This commit is contained in:
Saoud Rizwan
2024-09-03 17:08:29 -04:00
parent 0badfa2706
commit c209198b23
14 changed files with 383 additions and 187 deletions

View File

@@ -1,9 +1,10 @@
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
import { ApiConfiguration, ApiModelId, ModelInfo } from "../shared/api" import { ApiConfiguration, ModelInfo } from "../shared/api"
import { AnthropicHandler } from "./anthropic" import { AnthropicHandler } from "./anthropic"
import { AwsBedrockHandler } from "./bedrock" import { AwsBedrockHandler } from "./bedrock"
import { OpenRouterHandler } from "./openrouter" import { OpenRouterHandler } from "./openrouter"
import { VertexHandler } from "./vertex" import { VertexHandler } from "./vertex"
import { OpenAiHandler } from "./openai"
export interface ApiHandlerMessageResponse { export interface ApiHandlerMessageResponse {
message: Anthropic.Messages.Message message: Anthropic.Messages.Message
@@ -26,7 +27,7 @@ export interface ApiHandler {
> >
): any ): any
getModel(): { id: ApiModelId; info: ModelInfo } getModel(): { id: string; info: ModelInfo }
} }
export function buildApiHandler(configuration: ApiConfiguration): ApiHandler { export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
@@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new AwsBedrockHandler(options) return new AwsBedrockHandler(options)
case "vertex": case "vertex":
return new VertexHandler(options) return new VertexHandler(options)
case "openai":
return new OpenAiHandler(options)
default: default:
return new AnthropicHandler(options) return new AnthropicHandler(options)
} }

74
src/api/openai.ts Normal file
View File

@@ -0,0 +1,74 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, ApiHandlerMessageResponse, withoutImageData } from "."
import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
export class OpenAiHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
constructor(options: ApiHandlerOptions) {
this.options = options
this.client = new OpenAI({
baseURL: this.options.openAiBaseUrl,
apiKey: this.options.openAiApiKey,
})
}
async createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
tools: Anthropic.Messages.Tool[]
): Promise<ApiHandlerMessageResponse> {
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}))
const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.options.openAiModelId ?? "",
messages: openAiMessages,
tools: openAiTools,
tool_choice: "auto",
}
const completion = await this.client.chat.completions.create(createParams)
const errorMessage = (completion as any).error?.message
if (errorMessage) {
throw new Error(errorMessage)
}
const anthropicMessage = convertToAnthropicMessage(completion)
return { message: anthropicMessage }
}
createUserReadableRequest(
userContent: Array<
| Anthropic.TextBlockParam
| Anthropic.ImageBlockParam
| Anthropic.ToolUseBlockParam
| Anthropic.ToolResultBlockParam
>
): any {
return {
model: this.options.openAiModelId ?? "",
system: "(see SYSTEM_PROMPT in src/ClaudeDev.ts)",
messages: [{ conversation_history: "..." }, { role: "user", content: withoutImageData(userContent) }],
tools: "(see tools in src/ClaudeDev.ts)",
tool_choice: "auto",
}
}
getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.openAiModelId ?? "",
info: openAiModelInfoSaneDefaults,
}
}
}

View File

@@ -8,7 +8,7 @@ import {
OpenRouterModelId, OpenRouterModelId,
openRouterModels, openRouterModels,
} from "../shared/api" } from "../shared/api"
import { convertToOpenAiMessages } from "../utils/openai-format" import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
export class OpenRouterHandler implements ApiHandler { export class OpenRouterHandler implements ApiHandler {
private options: ApiHandlerOptions private options: ApiHandlerOptions
@@ -68,57 +68,7 @@ export class OpenRouterHandler implements ApiHandler {
throw new Error(errorMessage) throw new Error(errorMessage)
} }
// Convert OpenAI response to Anthropic format const anthropicMessage = convertToAnthropicMessage(completion)
const openAiMessage = completion.choices[0].message
const anthropicMessage: Anthropic.Messages.Message = {
id: completion.id,
type: "message",
role: openAiMessage.role, // always "assistant"
content: [
{
type: "text",
text: openAiMessage.content || "",
},
],
model: completion.model,
stop_reason: (() => {
switch (completion.choices[0].finish_reason) {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
case "tool_calls":
return "tool_use"
case "content_filter": // Anthropic doesn't have an exact equivalent
default:
return null
}
})(),
stop_sequence: null, // which custom stop_sequence was generated, if any (not applicable if you don't use stop_sequence)
usage: {
input_tokens: completion.usage?.prompt_tokens || 0,
output_tokens: completion.usage?.completion_tokens || 0,
},
}
if (openAiMessage.tool_calls && openAiMessage.tool_calls.length > 0) {
anthropicMessage.content.push(
...openAiMessage.tool_calls.map((toolCall): Anthropic.ToolUseBlock => {
let parsedInput = {}
try {
parsedInput = JSON.parse(toolCall.function.arguments || "{}")
} catch (error) {
console.error("Failed to parse tool arguments:", error)
}
return {
type: "tool_use",
id: toolCall.id,
name: toolCall.function.name,
input: parsedInput,
}
})
)
}
return { message: anthropicMessage } return { message: anthropicMessage }
} }

View File

@@ -1,7 +1,7 @@
import { Anthropic } from "@anthropic-ai/sdk" import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode" import * as vscode from "vscode"
import { ClaudeDev } from "../ClaudeDev" import { ClaudeDev } from "../ClaudeDev"
import { ApiModelId, ApiProvider } from "../shared/api" import { ApiProvider } from "../shared/api"
import { ExtensionMessage } from "../shared/ExtensionMessage" import { ExtensionMessage } from "../shared/ExtensionMessage"
import { WebviewMessage } from "../shared/WebviewMessage" import { WebviewMessage } from "../shared/WebviewMessage"
import { downloadTask, findLast, getNonce, getUri, selectImages } from "../utils" import { downloadTask, findLast, getNonce, getUri, selectImages } from "../utils"
@@ -16,7 +16,7 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
*/ */
type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" | "openAiApiKey"
type GlobalStateKey = type GlobalStateKey =
| "apiProvider" | "apiProvider"
| "apiModelId" | "apiModelId"
@@ -27,6 +27,8 @@ type GlobalStateKey =
| "customInstructions" | "customInstructions"
| "alwaysAllowReadOnly" | "alwaysAllowReadOnly"
| "taskHistory" | "taskHistory"
| "openAiBaseUrl"
| "openAiModelId"
export class ClaudeDevProvider implements vscode.WebviewViewProvider { export class ClaudeDevProvider implements vscode.WebviewViewProvider {
public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension. public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
@@ -314,6 +316,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
awsRegion, awsRegion,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl,
openAiApiKey,
openAiModelId,
} = message.apiConfiguration } = message.apiConfiguration
await this.updateGlobalState("apiProvider", apiProvider) await this.updateGlobalState("apiProvider", apiProvider)
await this.updateGlobalState("apiModelId", apiModelId) await this.updateGlobalState("apiModelId", apiModelId)
@@ -325,6 +330,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("awsRegion", awsRegion) await this.updateGlobalState("awsRegion", awsRegion)
await this.updateGlobalState("vertexProjectId", vertexProjectId) await this.updateGlobalState("vertexProjectId", vertexProjectId)
await this.updateGlobalState("vertexRegion", vertexRegion) await this.updateGlobalState("vertexRegion", vertexRegion)
await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
await this.storeSecret("openAiApiKey", openAiApiKey)
await this.updateGlobalState("openAiModelId", openAiModelId)
this.claudeDev?.updateApi(message.apiConfiguration) this.claudeDev?.updateApi(message.apiConfiguration)
} }
await this.postStateToWebview() await this.postStateToWebview()
@@ -615,13 +623,16 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
awsRegion, awsRegion,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl,
openAiApiKey,
openAiModelId,
lastShownAnnouncementId, lastShownAnnouncementId,
customInstructions, customInstructions,
alwaysAllowReadOnly, alwaysAllowReadOnly,
taskHistory, taskHistory,
] = await Promise.all([ ] = await Promise.all([
this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>, this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
this.getGlobalState("apiModelId") as Promise<ApiModelId | undefined>, this.getGlobalState("apiModelId") as Promise<string | undefined>,
this.getSecret("apiKey") as Promise<string | undefined>, this.getSecret("apiKey") as Promise<string | undefined>,
this.getSecret("openRouterApiKey") as Promise<string | undefined>, this.getSecret("openRouterApiKey") as Promise<string | undefined>,
this.getSecret("awsAccessKey") as Promise<string | undefined>, this.getSecret("awsAccessKey") as Promise<string | undefined>,
@@ -630,6 +641,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
this.getGlobalState("awsRegion") as Promise<string | undefined>, this.getGlobalState("awsRegion") as Promise<string | undefined>,
this.getGlobalState("vertexProjectId") as Promise<string | undefined>, this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
this.getGlobalState("vertexRegion") as Promise<string | undefined>, this.getGlobalState("vertexRegion") as Promise<string | undefined>,
this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
this.getSecret("openAiApiKey") as Promise<string | undefined>,
this.getGlobalState("openAiModelId") as Promise<string | undefined>,
this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>, this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
this.getGlobalState("customInstructions") as Promise<string | undefined>, this.getGlobalState("customInstructions") as Promise<string | undefined>,
this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>, this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -662,6 +676,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
awsRegion, awsRegion,
vertexProjectId, vertexProjectId,
vertexRegion, vertexRegion,
openAiBaseUrl,
openAiApiKey,
openAiModelId,
}, },
lastShownAnnouncementId, lastShownAnnouncementId,
customInstructions, customInstructions,
@@ -739,6 +756,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
"awsAccessKey", "awsAccessKey",
"awsSecretKey", "awsSecretKey",
"awsSessionToken", "awsSessionToken",
"openAiApiKey",
] ]
for (const key of secretKeys) { for (const key of secretKeys) {
await this.storeSecret(key, undefined) await this.storeSecret(key, undefined)

View File

@@ -1,7 +1,7 @@
export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai"
export interface ApiHandlerOptions { export interface ApiHandlerOptions {
apiModelId?: ApiModelId apiModelId?: string
apiKey?: string // anthropic apiKey?: string // anthropic
openRouterApiKey?: string openRouterApiKey?: string
awsAccessKey?: string awsAccessKey?: string
@@ -10,6 +10,9 @@ export interface ApiHandlerOptions {
awsRegion?: string awsRegion?: string
vertexProjectId?: string vertexProjectId?: string
vertexRegion?: string vertexRegion?: string
openAiBaseUrl?: string
openAiApiKey?: string
openAiModelId?: string
} }
export type ApiConfiguration = ApiHandlerOptions & { export type ApiConfiguration = ApiHandlerOptions & {
@@ -29,8 +32,6 @@ export interface ModelInfo {
cacheReadsPrice?: number cacheReadsPrice?: number
} }
export type ApiModelId = AnthropicModelId | OpenRouterModelId | BedrockModelId | VertexModelId
// Anthropic // Anthropic
// https://docs.anthropic.com/en/docs/about-claude/models // https://docs.anthropic.com/en/docs/about-claude/models
export type AnthropicModelId = keyof typeof anthropicModels export type AnthropicModelId = keyof typeof anthropicModels
@@ -292,3 +293,12 @@ export const vertexModels = {
outputPrice: 1.25, outputPrice: 1.25,
}, },
} as const satisfies Record<string, ModelInfo> } as const satisfies Record<string, ModelInfo>
export const openAiModelInfoSaneDefaults: ModelInfo = {
maxTokens: -1,
contextWindow: 128_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 0,
outputPrice: 0,
}

View File

@@ -142,3 +142,60 @@ export function convertToOpenAiMessages(
return openAiMessages return openAiMessages
} }
// Convert OpenAI response to Anthropic format
export function convertToAnthropicMessage(
completion: OpenAI.Chat.Completions.ChatCompletion
): Anthropic.Messages.Message {
const openAiMessage = completion.choices[0].message
const anthropicMessage: Anthropic.Messages.Message = {
id: completion.id,
type: "message",
role: openAiMessage.role, // always "assistant"
content: [
{
type: "text",
text: openAiMessage.content || "",
},
],
model: completion.model,
stop_reason: (() => {
switch (completion.choices[0].finish_reason) {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
case "tool_calls":
return "tool_use"
case "content_filter": // Anthropic doesn't have an exact equivalent
default:
return null
}
})(),
stop_sequence: null, // which custom stop_sequence was generated, if any (not applicable if you don't use stop_sequence)
usage: {
input_tokens: completion.usage?.prompt_tokens || 0,
output_tokens: completion.usage?.completion_tokens || 0,
},
}
if (openAiMessage.tool_calls && openAiMessage.tool_calls.length > 0) {
anthropicMessage.content.push(
...openAiMessage.tool_calls.map((toolCall): Anthropic.ToolUseBlock => {
let parsedInput = {}
try {
parsedInput = JSON.parse(toolCall.function.arguments || "{}")
} catch (error) {
console.error("Failed to parse tool arguments:", error)
}
return {
type: "tool_use",
id: toolCall.id,
name: toolCall.function.name,
input: parsedInput,
}
})
)
}
return anthropicMessage
}

View File

@@ -2,12 +2,12 @@ import { VSCodeDropdown, VSCodeLink, VSCodeOption, VSCodeTextField } from "@vsco
import React, { useMemo } from "react" import React, { useMemo } from "react"
import { import {
ApiConfiguration, ApiConfiguration,
ApiModelId,
ModelInfo, ModelInfo,
anthropicDefaultModelId, anthropicDefaultModelId,
anthropicModels, anthropicModels,
bedrockDefaultModelId, bedrockDefaultModelId,
bedrockModels, bedrockModels,
openAiModelInfoSaneDefaults,
openRouterDefaultModelId, openRouterDefaultModelId,
openRouterModels, openRouterModels,
vertexDefaultModelId, vertexDefaultModelId,
@@ -69,11 +69,16 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
<label htmlFor="api-provider"> <label htmlFor="api-provider">
<span style={{ fontWeight: 500 }}>API Provider</span> <span style={{ fontWeight: 500 }}>API Provider</span>
</label> </label>
<VSCodeDropdown id="api-provider" value={selectedProvider} onChange={handleInputChange("apiProvider")}> <VSCodeDropdown
id="api-provider"
value={selectedProvider}
onChange={handleInputChange("apiProvider")}
style={{ minWidth: 125 }}>
<VSCodeOption value="anthropic">Anthropic</VSCodeOption> <VSCodeOption value="anthropic">Anthropic</VSCodeOption>
<VSCodeOption value="openrouter">OpenRouter</VSCodeOption> <VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption> <VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption> <VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
</VSCodeDropdown> </VSCodeDropdown>
</div> </div>
@@ -256,6 +261,47 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
</div> </div>
)} )}
{selectedProvider === "openai" && (
<div>
<VSCodeTextField
value={apiConfiguration?.openAiBaseUrl || ""}
style={{ width: "100%" }}
type="url"
onInput={handleInputChange("openAiBaseUrl")}
placeholder={"e.g. http://localhost:11434"}>
<span style={{ fontWeight: 500 }}>Base URL</span>
</VSCodeTextField>
<VSCodeTextField
value={apiConfiguration?.openAiApiKey || ""}
style={{ width: "100%" }}
type="password"
onInput={handleInputChange("openAiApiKey")}
placeholder="e.g. ollama">
<span style={{ fontWeight: 500 }}>API Key</span>
</VSCodeTextField>
<VSCodeTextField
value={apiConfiguration?.openAiModelId || ""}
style={{ width: "100%" }}
onInput={handleInputChange("openAiModelId")}
placeholder={"e.g. llama3.1"}>
<span style={{ fontWeight: 500 }}>Model ID</span>
</VSCodeTextField>
<p
style={{
fontSize: "12px",
marginTop: "5px",
color: "var(--vscode-descriptionForeground)",
}}>
You can use any OpenAI compatible API with models that support tool use.{" "}
<span style={{ color: "var(--vscode-errorForeground)" }}>
(<span style={{ fontWeight: 500 }}>Note:</span> Claude Dev uses complex prompts, so results
may vary depending on the quality of the model you choose. Less capable models may not work
as expected.)
</span>
</p>
</div>
)}
{apiErrorMessage && ( {apiErrorMessage && (
<p <p
style={{ style={{
@@ -267,7 +313,7 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
</p> </p>
)} )}
{showModelOptions && ( {selectedProvider !== "openai" && showModelOptions && (
<> <>
<div className="dropdown-container"> <div className="dropdown-container">
<label htmlFor="model-id"> <label htmlFor="model-id">
@@ -365,8 +411,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
const provider = apiConfiguration?.apiProvider || "anthropic" const provider = apiConfiguration?.apiProvider || "anthropic"
const modelId = apiConfiguration?.apiModelId const modelId = apiConfiguration?.apiModelId
const getProviderData = (models: Record<string, ModelInfo>, defaultId: ApiModelId) => { const getProviderData = (models: Record<string, ModelInfo>, defaultId: string) => {
let selectedModelId: ApiModelId let selectedModelId: string
let selectedModelInfo: ModelInfo let selectedModelInfo: ModelInfo
if (modelId && modelId in models) { if (modelId && modelId in models) {
selectedModelId = modelId selectedModelId = modelId
@@ -386,6 +432,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
return getProviderData(bedrockModels, bedrockDefaultModelId) return getProviderData(bedrockModels, bedrockDefaultModelId)
case "vertex": case "vertex":
return getProviderData(vertexModels, vertexDefaultModelId) return getProviderData(vertexModels, vertexDefaultModelId)
case "openai":
return {
selectedProvider: provider,
selectedModelId: apiConfiguration?.openAiModelId ?? "",
selectedModelInfo: openAiModelInfoSaneDefaults,
}
default: default:
return getProviderData(anthropicModels, anthropicDefaultModelId) return getProviderData(anthropicModels, anthropicDefaultModelId)
} }

View File

@@ -497,9 +497,6 @@ const ChatView = ({
cacheReads={apiMetrics.totalCacheReads} cacheReads={apiMetrics.totalCacheReads}
totalCost={apiMetrics.totalCost} totalCost={apiMetrics.totalCost}
onClose={handleTaskCloseButtonClick} onClose={handleTaskCloseButtonClick}
isHidden={isHidden}
vscodeUriScheme={uriScheme}
apiProvider={apiConfiguration?.apiProvider}
/> />
) : ( ) : (
<> <>

View File

@@ -108,17 +108,21 @@ const HistoryPreview = ({ showHistoryView }: HistoryPreviewProps) => {
<span> <span>
Tokens: {item.tokensIn?.toLocaleString()} {item.tokensOut?.toLocaleString()} Tokens: {item.tokensIn?.toLocaleString()} {item.tokensOut?.toLocaleString()}
</span> </span>
{" • "}
{item.cacheWrites && item.cacheReads && ( {item.cacheWrites && item.cacheReads && (
<> <>
{" • "}
<span> <span>
Cache: +{item.cacheWrites?.toLocaleString()} {" "} Cache: +{item.cacheWrites?.toLocaleString()} {" "}
{item.cacheReads?.toLocaleString()} {item.cacheReads?.toLocaleString()}
</span> </span>
{" • "}
</> </>
)} )}
{!!item.totalCost && (
<>
{" • "}
<span>API Cost: ${item.totalCost?.toFixed(4)}</span> <span>API Cost: ${item.totalCost?.toFixed(4)}</span>
</>
)}
</div> </div>
</div> </div>
</div> </div>

View File

@@ -63,6 +63,17 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
) )
} }
const ExportButton = ({ itemId }: { itemId: string }) => (
<VSCodeButton
appearance="icon"
onClick={(e) => {
e.stopPropagation()
handleExportMd(itemId)
}}>
<div style={{ fontSize: "11px", fontWeight: 500, opacity: 1 }}>EXPORT .MD</div>
</VSCodeButton>
)
return ( return (
<> <>
<style> <style>
@@ -213,6 +224,12 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
{highlightText(item.task, searchQuery)} {highlightText(item.task, searchQuery)}
</div> </div>
<div style={{ display: "flex", flexDirection: "column", gap: "4px" }}> <div style={{ display: "flex", flexDirection: "column", gap: "4px" }}>
<div
style={{
display: "flex",
justifyContent: "space-between",
alignItems: "center",
}}>
<div <div
style={{ style={{
display: "flex", display: "flex",
@@ -262,6 +279,9 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
{item.tokensOut?.toLocaleString()} {item.tokensOut?.toLocaleString()}
</span> </span>
</div> </div>
{!item.totalCost && <ExportButton itemId={item.id} />}
</div>
{item.cacheWrites && item.cacheReads && ( {item.cacheWrites && item.cacheReads && (
<div <div
style={{ style={{
@@ -313,6 +333,7 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
</span> </span>
</div> </div>
)} )}
{!!item.totalCost && (
<div <div
style={{ style={{
display: "flex", display: "flex",
@@ -332,17 +353,9 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
${item.totalCost?.toFixed(4)} ${item.totalCost?.toFixed(4)}
</span> </span>
</div> </div>
<VSCodeButton <ExportButton itemId={item.id} />
appearance="icon"
onClick={(e) => {
e.stopPropagation()
handleExportMd(item.id)
}}>
<div style={{ fontSize: "11px", fontWeight: 500, opacity: 1 }}>
EXPORT .MD
</div>
</VSCodeButton>
</div> </div>
)}
</div> </div>
</div> </div>
</div> </div>

View File

@@ -1,9 +1,4 @@
import { import { VSCodeButton, VSCodeCheckbox, VSCodeLink, VSCodeTextArea } from "@vscode/webview-ui-toolkit/react"
VSCodeButton,
VSCodeCheckbox,
VSCodeLink,
VSCodeTextArea
} from "@vscode/webview-ui-toolkit/react"
import { useEffect, useState } from "react" import { useEffect, useState } from "react"
import { useExtensionState } from "../context/ExtensionStateContext" import { useExtensionState } from "../context/ExtensionStateContext"
import { validateApiConfiguration } from "../utils/validate" import { validateApiConfiguration } from "../utils/validate"

View File

@@ -1,8 +1,8 @@
import { VSCodeButton } from "@vscode/webview-ui-toolkit/react" import { VSCodeButton } from "@vscode/webview-ui-toolkit/react"
import React, { useEffect, useRef, useState } from "react" import React, { useEffect, useRef, useState } from "react"
import { useWindowSize } from "react-use" import { useWindowSize } from "react-use"
import { ApiProvider } from "../../../src/shared/api"
import { ClaudeMessage } from "../../../src/shared/ExtensionMessage" import { ClaudeMessage } from "../../../src/shared/ExtensionMessage"
import { useExtensionState } from "../context/ExtensionStateContext"
import { vscode } from "../utils/vscode" import { vscode } from "../utils/vscode"
import Thumbnails from "./Thumbnails" import Thumbnails from "./Thumbnails"
@@ -15,9 +15,6 @@ interface TaskHeaderProps {
cacheReads?: number cacheReads?: number
totalCost: number totalCost: number
onClose: () => void onClose: () => void
isHidden: boolean
vscodeUriScheme?: string
apiProvider?: ApiProvider
} }
const TaskHeader: React.FC<TaskHeaderProps> = ({ const TaskHeader: React.FC<TaskHeaderProps> = ({
@@ -29,10 +26,8 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
cacheReads, cacheReads,
totalCost, totalCost,
onClose, onClose,
isHidden,
vscodeUriScheme,
apiProvider,
}) => { }) => {
const { apiConfiguration } = useExtensionState()
const [isExpanded, setIsExpanded] = useState(false) const [isExpanded, setIsExpanded] = useState(false)
const [showSeeMore, setShowSeeMore] = useState(false) const [showSeeMore, setShowSeeMore] = useState(false)
const textContainerRef = useRef<HTMLDivElement>(null) const textContainerRef = useRef<HTMLDivElement>(null)
@@ -100,6 +95,18 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
vscode.postMessage({ type: "exportCurrentTask" }) vscode.postMessage({ type: "exportCurrentTask" })
} }
const ExportButton = () => (
<VSCodeButton
appearance="icon"
onClick={handleDownload}
style={{
marginBottom: "-2px",
marginRight: "-2.5px",
}}>
<div style={{ fontSize: "10.5px", fontWeight: "bold", opacity: 0.6 }}>EXPORT .MD</div>
</VSCodeButton>
)
return ( return (
<div style={{ padding: "10px 13px 10px 13px" }}> <div style={{ padding: "10px 13px 10px 13px" }}>
<div <div
@@ -196,6 +203,12 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
)} )}
{task.images && task.images.length > 0 && <Thumbnails images={task.images} />} {task.images && task.images.length > 0 && <Thumbnails images={task.images} />}
<div style={{ display: "flex", flexDirection: "column", gap: "4px" }}> <div style={{ display: "flex", flexDirection: "column", gap: "4px" }}>
<div
style={{
display: "flex",
justifyContent: "space-between",
alignItems: "center",
}}>
<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}> <div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
<span style={{ fontWeight: "bold" }}>Tokens:</span> <span style={{ fontWeight: "bold" }}>Tokens:</span>
<span style={{ display: "flex", alignItems: "center", gap: "3px" }}> <span style={{ display: "flex", alignItems: "center", gap: "3px" }}>
@@ -213,6 +226,9 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
{tokensOut?.toLocaleString()} {tokensOut?.toLocaleString()}
</span> </span>
</div> </div>
{apiConfiguration?.apiProvider === "openai" && <ExportButton />}
</div>
{(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && ( {(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}> <div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
<span style={{ fontWeight: "bold" }}>Cache:</span> <span style={{ fontWeight: "bold" }}>Cache:</span>
@@ -232,6 +248,7 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
</span> </span>
</div> </div>
)} )}
{apiConfiguration?.apiProvider !== "openai" && (
<div <div
style={{ style={{
display: "flex", display: "flex",
@@ -242,16 +259,9 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
<span style={{ fontWeight: "bold" }}>API Cost:</span> <span style={{ fontWeight: "bold" }}>API Cost:</span>
<span>${totalCost?.toFixed(4)}</span> <span>${totalCost?.toFixed(4)}</span>
</div> </div>
<VSCodeButton <ExportButton />
appearance="icon"
onClick={handleDownload}
style={{
marginBottom: "-2px",
marginRight: "-2.5px",
}}>
<div style={{ fontSize: "10.5px", fontWeight: "bold", opacity: 0.6 }}>EXPORT .MD</div>
</VSCodeButton>
</div> </div>
)}
</div> </div>
</div> </div>
{/* {apiProvider === "kodu" && ( {/* {apiProvider === "kodu" && (

View File

@@ -31,9 +31,13 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
setState(message.state) setState(message.state)
const config = message.state?.apiConfiguration const config = message.state?.apiConfiguration
const hasKey = config const hasKey = config
? [config.apiKey, config.openRouterApiKey, config.awsRegion, config.vertexProjectId].some( ? [
(key) => key !== undefined config.apiKey,
) config.openRouterApiKey,
config.awsRegion,
config.vertexProjectId,
config.openAiApiKey,
].some((key) => key !== undefined)
: false : false
setShowWelcome(!hasKey) setShowWelcome(!hasKey)
setDidHydrateState(true) setDidHydrateState(true)

View File

@@ -23,6 +23,15 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
return "You must provide a valid Google Cloud Project ID and Region." return "You must provide a valid Google Cloud Project ID and Region."
} }
break break
case "openai":
if (
!apiConfiguration.openAiBaseUrl ||
!apiConfiguration.openAiApiKey ||
!apiConfiguration.openAiModelId
) {
return "You must provide a valid base URL, API key, and model ID."
}
break
} }
} }
return undefined return undefined