Add a screen for custom prompts

This commit is contained in:
Matt Rubens
2025-01-13 03:16:10 -05:00
parent 4027e1c10c
commit 75e308b033
21 changed files with 1044 additions and 238 deletions

View File

@@ -780,14 +780,15 @@ export class Cline {
})
}
const { browserViewportSize, preferredLanguage, mode } = await this.providerRef.deref()?.getState() ?? {}
const { browserViewportSize, preferredLanguage, mode, customPrompts } = await this.providerRef.deref()?.getState() ?? {}
const systemPrompt = await SYSTEM_PROMPT(
cwd,
this.api.getModel().info.supportsComputerUse ?? false,
mcpHub,
this.diffStrategy,
browserViewportSize,
mode
mode,
customPrompts
) + await addCustomInstructions(this.customInstructions ?? '', cwd, preferredLanguage)
// If the previous API request's total token usage is close to the context window, truncate the conversation history to free up space for the new request

View File

@@ -1,4 +1,4 @@
import { architectMode } from "./modes"
import { architectMode, defaultPrompts } from "../../shared/modes"
import { getToolDescriptionsForMode } from "./tools"
import {
getRulesSection,
@@ -20,7 +20,8 @@ export const ARCHITECT_PROMPT = async (
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
) => `You are Cline, a software architecture expert specializing in analyzing codebases, identifying patterns, and providing high-level technical guidance. You excel at understanding complex systems, evaluating architectural decisions, and suggesting improvements while maintaining a read-only approach to the codebase. Make sure to help the user come up with a solid implementation plan for their project and don't rush to switch to implementing code.
customPrompt?: string,
) => `${customPrompt || defaultPrompts[architectMode]}
${getSharedToolUseSection()}

View File

@@ -1,4 +1,4 @@
import { Mode, askMode } from "./modes"
import { Mode, askMode, defaultPrompts } from "../../shared/modes"
import { getToolDescriptionsForMode } from "./tools"
import {
getRulesSection,
@@ -21,7 +21,8 @@ export const ASK_PROMPT = async (
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
) => `You are Cline, a knowledgeable technical assistant focused on answering questions and providing information about software development, technology, and related topics. You can analyze code, explain concepts, and access external resources while maintaining a read-only approach to the codebase. Make sure to answer the user's questions and don't rush to switch to implementing code.
customPrompt?: string,
) => `${customPrompt || defaultPrompts[askMode]}
${getSharedToolUseSection()}

View File

@@ -1,4 +1,4 @@
import { Mode, codeMode } from "./modes"
import { Mode, codeMode, defaultPrompts } from "../../shared/modes"
import { getToolDescriptionsForMode } from "./tools"
import {
getRulesSection,
@@ -21,7 +21,8 @@ export const CODE_PROMPT = async (
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
) => `You are Cline, a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
customPrompt?: string,
) => `${customPrompt || defaultPrompts[codeMode]}
${getSharedToolUseSection()}

View File

@@ -63,15 +63,16 @@ export const SYSTEM_PROMPT = async (
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
browserViewportSize?: string,
mode: Mode = codeMode,
mode: Mode = codeMode,
customPrompts?: { ask?: string; code?: string; architect?: string; enhance?: string },
) => {
switch (mode) {
case architectMode:
return ARCHITECT_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize)
return ARCHITECT_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize, customPrompts?.architect)
case askMode:
return ASK_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize)
return ASK_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize, customPrompts?.ask)
default:
return CODE_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize)
return CODE_PROMPT(cwd, supportsComputerUse, mcpHub, diffStrategy, browserViewportSize, customPrompts?.code)
}
}

View File

@@ -17,6 +17,8 @@ import { findLast } from "../../shared/array"
import { ApiConfigMeta, ExtensionMessage } from "../../shared/ExtensionMessage"
import { HistoryItem } from "../../shared/HistoryItem"
import { WebviewMessage } from "../../shared/WebviewMessage"
import { defaultPrompts } from "../../shared/modes"
import { SYSTEM_PROMPT, addCustomInstructions } from "../prompts/system"
import { fileExistsAtPath } from "../../utils/fs"
import { Cline } from "../Cline"
import { openMention } from "../mentions"
@@ -28,7 +30,7 @@ import { enhancePrompt } from "../../utils/enhance-prompt"
import { getCommitInfo, searchCommits, getWorkingState } from "../../utils/git"
import { ConfigManager } from "../config/ConfigManager"
import { Mode } from "../prompts/types"
import { codeMode } from "../prompts/system"
import { codeMode, CustomPrompts } from "../../shared/modes"
/*
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -93,6 +95,8 @@ type GlobalStateKey =
| "listApiConfigMeta"
| "mode"
| "modeApiConfigs"
| "customPrompts"
| "enhancementApiConfigId"
export const GlobalFileNames = {
apiConversationHistory: "api_conversation_history.json",
@@ -111,7 +115,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
private cline?: Cline
private workspaceTracker?: WorkspaceTracker
mcpHub?: McpHub
private latestAnnouncementId = "dec-10-2024" // update to some unique identifier when we add a new announcement
private latestAnnouncementId = "jan-13-2025-custom-prompt" // update to some unique identifier when we add a new announcement
configManager: ConfigManager
constructor(
@@ -727,6 +731,32 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.postStateToWebview()
break
case "updatePrompt":
if (message.promptMode && message.customPrompt !== undefined) {
const existingPrompts = await this.getGlobalState("customPrompts") || {}
const updatedPrompts = {
...existingPrompts,
[message.promptMode]: message.customPrompt
}
await this.updateGlobalState("customPrompts", updatedPrompts)
// Get current state and explicitly include customPrompts
const currentState = await this.getState()
const stateWithPrompts = {
...currentState,
customPrompts: updatedPrompts
}
// Post state with prompts
this.view?.webview.postMessage({
type: "state",
state: stateWithPrompts
})
}
break
case "deleteMessage": {
const answer = await vscode.window.showInformationMessage(
"What would you like to delete?",
@@ -797,16 +827,28 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("screenshotQuality", message.value)
await this.postStateToWebview()
break
case "enhancementApiConfigId":
await this.updateGlobalState("enhancementApiConfigId", message.text)
await this.postStateToWebview()
break
case "enhancePrompt":
if (message.text) {
try {
const { apiConfiguration } = await this.getState()
const enhanceConfig = {
...apiConfiguration,
apiProvider: "openrouter" as const,
openRouterModelId: "gpt-4o",
const { apiConfiguration, customPrompts, listApiConfigMeta, enhancementApiConfigId } = await this.getState()
// Try to get enhancement config first, fall back to current config
let configToUse: ApiConfiguration = apiConfiguration
if (enhancementApiConfigId) {
const config = listApiConfigMeta?.find(c => c.id === enhancementApiConfigId)
if (config?.name) {
const loadedConfig = await this.configManager.LoadConfig(config.name)
if (loadedConfig.apiProvider) {
configToUse = loadedConfig
}
}
}
const enhancedPrompt = await enhancePrompt(enhanceConfig, message.text)
const enhancedPrompt = await enhancePrompt(configToUse, message.text, customPrompts?.enhance)
await this.postMessageToWebview({
type: "enhancedPrompt",
text: enhancedPrompt
@@ -814,11 +856,37 @@ export class ClineProvider implements vscode.WebviewViewProvider {
} catch (error) {
console.error("Error enhancing prompt:", error)
vscode.window.showErrorMessage("Failed to enhance prompt")
await this.postMessageToWebview({
type: "enhancedPrompt"
})
}
}
break
case "getSystemPrompt":
try {
const { apiConfiguration, customPrompts, customInstructions, preferredLanguage, browserViewportSize, mcpEnabled } = await this.getState()
const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) || ''
const fullPrompt = await SYSTEM_PROMPT(
cwd,
apiConfiguration.openRouterModelInfo?.supportsComputerUse ?? false,
mcpEnabled ? this.mcpHub : undefined,
undefined,
browserViewportSize ?? "900x600",
message.mode,
customPrompts
) + await addCustomInstructions(customInstructions ?? '', cwd, preferredLanguage)
await this.postMessageToWebview({
type: "systemPrompt",
text: fullPrompt,
mode: message.mode
})
} catch (error) {
console.error("Error getting system prompt:", error)
vscode.window.showErrorMessage("Failed to get system prompt")
}
break
case "searchCommits": {
const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0)
if (cwd) {
@@ -1482,6 +1550,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
currentApiConfigName,
listApiConfigMeta,
mode,
customPrompts,
enhancementApiConfigId,
} = await this.getState()
const allowedCommands = vscode.workspace
@@ -1500,11 +1570,11 @@ export class ClineProvider implements vscode.WebviewViewProvider {
uriScheme: vscode.env.uriScheme,
clineMessages: this.cline?.clineMessages || [],
taskHistory: (taskHistory || [])
.filter((item) => item.ts && item.task)
.sort((a, b) => b.ts - a.ts),
.filter((item: HistoryItem) => item.ts && item.task)
.sort((a: HistoryItem, b: HistoryItem) => b.ts - a.ts),
soundEnabled: soundEnabled ?? false,
diffEnabled: diffEnabled ?? true,
shouldShowAnnouncement: false, // lastShownAnnouncementId !== this.latestAnnouncementId,
shouldShowAnnouncement: lastShownAnnouncementId !== this.latestAnnouncementId,
allowedCommands,
soundVolume: soundVolume ?? 0.5,
browserViewportSize: browserViewportSize ?? "900x600",
@@ -1519,6 +1589,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
currentApiConfigName: currentApiConfigName ?? "default",
listApiConfigMeta: listApiConfigMeta ?? [],
mode: mode ?? codeMode,
customPrompts: customPrompts ?? {},
enhancementApiConfigId,
}
}
@@ -1630,6 +1702,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
listApiConfigMeta,
mode,
modeApiConfigs,
customPrompts,
enhancementApiConfigId,
] = await Promise.all([
this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
this.getGlobalState("apiModelId") as Promise<string | undefined>,
@@ -1686,6 +1760,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getGlobalState("listApiConfigMeta") as Promise<ApiConfigMeta[] | undefined>,
this.getGlobalState("mode") as Promise<Mode | undefined>,
this.getGlobalState("modeApiConfigs") as Promise<Record<Mode, string> | undefined>,
this.getGlobalState("customPrompts") as Promise<CustomPrompts | undefined>,
this.getGlobalState("enhancementApiConfigId") as Promise<string | undefined>,
])
let apiProvider: ApiProvider
@@ -1786,6 +1862,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
currentApiConfigName: currentApiConfigName ?? "default",
listApiConfigMeta: listApiConfigMeta ?? [],
modeApiConfigs: modeApiConfigs ?? {} as Record<Mode, string>,
customPrompts: customPrompts ?? {},
enhancementApiConfigId,
}
}

View File

@@ -62,6 +62,7 @@ jest.mock('vscode', () => ({
},
window: {
showInformationMessage: jest.fn(),
showErrorMessage: jest.fn(),
},
workspace: {
getConfiguration: jest.fn().mockReturnValue({
@@ -113,6 +114,13 @@ jest.mock('../../../api', () => ({
buildApiHandler: jest.fn()
}))
// Mock system prompt
jest.mock('../../prompts/system', () => ({
SYSTEM_PROMPT: jest.fn().mockImplementation(async () => 'mocked system prompt'),
codeMode: 'code',
addCustomInstructions: jest.fn().mockImplementation(async () => '')
}))
// Mock WorkspaceTracker
jest.mock('../../../integrations/workspace/WorkspaceTracker', () => {
return jest.fn().mockImplementation(() => ({
@@ -504,6 +512,106 @@ describe('ClineProvider', () => {
expect(mockPostMessage).toHaveBeenCalled()
})
test('handles updatePrompt message correctly', async () => {
provider.resolveWebviewView(mockWebviewView)
const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0]
// Mock existing prompts
const existingPrompts = {
code: 'existing code prompt',
architect: 'existing architect prompt'
}
;(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => {
if (key === 'customPrompts') {
return existingPrompts
}
return undefined
})
// Test updating a prompt
await messageHandler({
type: 'updatePrompt',
promptMode: 'code',
customPrompt: 'new code prompt'
})
// Verify state was updated correctly
expect(mockContext.globalState.update).toHaveBeenCalledWith(
'customPrompts',
{
...existingPrompts,
code: 'new code prompt'
}
)
// Verify state was posted to webview
expect(mockPostMessage).toHaveBeenCalledWith(
expect.objectContaining({
type: 'state',
state: expect.objectContaining({
customPrompts: {
...existingPrompts,
code: 'new code prompt'
}
})
})
)
})
test('customPrompts defaults to empty object', async () => {
// Mock globalState.get to return undefined for customPrompts
(mockContext.globalState.get as jest.Mock).mockImplementation((key: string) => {
if (key === 'customPrompts') {
return undefined
}
return null
})
const state = await provider.getState()
expect(state.customPrompts).toEqual({})
})
test('saves mode config when updating API configuration', async () => {
// Setup mock context with mode and config name
mockContext = {
...mockContext,
globalState: {
...mockContext.globalState,
get: jest.fn((key: string) => {
if (key === 'mode') {
return 'code'
} else if (key === 'currentApiConfigName') {
return 'test-config'
}
return undefined
}),
update: jest.fn(),
keys: jest.fn().mockReturnValue([]),
}
} as unknown as vscode.ExtensionContext
// Create new provider with updated mock context
provider = new ClineProvider(mockContext, mockOutputChannel)
provider.resolveWebviewView(mockWebviewView)
const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0]
provider.configManager = {
ListConfig: jest.fn().mockResolvedValue([
{ name: 'test-config', id: 'test-id', apiProvider: 'anthropic' }
]),
SetModeConfig: jest.fn()
} as any
// Update API configuration
await messageHandler({
type: 'apiConfiguration',
apiConfiguration: { apiProvider: 'anthropic' }
})
// Should save config as default for current mode
expect(provider.configManager.SetModeConfig).toHaveBeenCalledWith('code', 'test-id')
})
test('file content includes line numbers', async () => {
const { extractTextFromFile } = require('../../../integrations/misc/extract-text')
const result = await extractTextFromFile('test.js')
@@ -654,4 +762,103 @@ describe('ClineProvider', () => {
expect(mockCline.overwriteApiConversationHistory).not.toHaveBeenCalled()
})
})
describe('getSystemPrompt', () => {
beforeEach(() => {
mockPostMessage.mockClear();
provider.resolveWebviewView(mockWebviewView);
});
const getMessageHandler = () => {
const mockCalls = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls;
expect(mockCalls.length).toBeGreaterThan(0);
return mockCalls[0][0];
};
test('handles mcpEnabled setting correctly', async () => {
// Mock getState to return mcpEnabled: true
jest.spyOn(provider, 'getState').mockResolvedValue({
apiConfiguration: {
apiProvider: 'openrouter' as const,
openRouterModelInfo: {
supportsComputerUse: true,
supportsPromptCache: false,
maxTokens: 4096,
contextWindow: 8192,
supportsImages: false,
inputPrice: 0.0,
outputPrice: 0.0,
description: undefined
}
},
mcpEnabled: true,
mode: 'code' as const
} as any);
const handler1 = getMessageHandler();
expect(typeof handler1).toBe('function');
await handler1({ type: 'getSystemPrompt', mode: 'code' });
// Verify mcpHub is passed when mcpEnabled is true
expect(mockPostMessage).toHaveBeenCalledWith(
expect.objectContaining({
type: 'systemPrompt',
text: expect.any(String)
})
);
// Mock getState to return mcpEnabled: false
jest.spyOn(provider, 'getState').mockResolvedValue({
apiConfiguration: {
apiProvider: 'openrouter' as const,
openRouterModelInfo: {
supportsComputerUse: true,
supportsPromptCache: false,
maxTokens: 4096,
contextWindow: 8192,
supportsImages: false,
inputPrice: 0.0,
outputPrice: 0.0,
description: undefined
}
},
mcpEnabled: false,
mode: 'code' as const
} as any);
const handler2 = getMessageHandler();
await handler2({ type: 'getSystemPrompt', mode: 'code' });
// Verify mcpHub is not passed when mcpEnabled is false
expect(mockPostMessage).toHaveBeenCalledWith(
expect.objectContaining({
type: 'systemPrompt',
text: expect.any(String)
})
);
});
test('returns empty prompt for enhance mode', async () => {
const enhanceHandler = getMessageHandler();
await enhanceHandler({ type: 'getSystemPrompt', mode: 'enhance' })
expect(mockPostMessage).toHaveBeenCalledWith(
expect.objectContaining({
type: 'systemPrompt',
text: ''
})
)
})
test('handles errors gracefully', async () => {
// Mock SYSTEM_PROMPT to throw an error
const systemPrompt = require('../../prompts/system')
jest.spyOn(systemPrompt, 'SYSTEM_PROMPT').mockRejectedValueOnce(new Error('Test error'))
const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0]
await messageHandler({ type: 'getSystemPrompt', mode: 'code' })
expect(vscode.window.showErrorMessage).toHaveBeenCalledWith('Failed to get system prompt')
})
})
})

View File

@@ -59,6 +59,12 @@ export function activate(context: vscode.ExtensionContext) {
}),
)
context.subscriptions.push(
vscode.commands.registerCommand("roo-cline.promptsButtonClicked", () => {
sidebarProvider.postMessageToWebview({ type: "action", action: "promptsButtonClicked" })
}),
)
const openClineInNewTab = async () => {
outputChannel.appendLine("Opening Cline in new tab")
// (this example uses webviewProvider activation event which is necessary to deserialize cached webview, but since we use retainContextWhenHidden, we don't need to use that event)

View File

@@ -4,7 +4,7 @@ import { ApiConfiguration, ApiProvider, ModelInfo } from "./api"
import { HistoryItem } from "./HistoryItem"
import { McpServer } from "./mcp"
import { GitCommit } from "../utils/git"
import { Mode } from "../core/prompts/types"
import { Mode, CustomPrompts } from "./modes"
// webview will hold state
export interface ExtensionMessage {
@@ -25,12 +25,15 @@ export interface ExtensionMessage {
| "enhancedPrompt"
| "commitSearchResults"
| "listApiConfig"
| "updatePrompt"
| "systemPrompt"
text?: string
action?:
| "chatButtonClicked"
| "mcpButtonClicked"
| "settingsButtonClicked"
| "historyButtonClicked"
| "promptsButtonClicked"
| "didBecomeVisible"
invoke?: "sendMessage" | "primaryButtonClick" | "secondaryButtonClick"
state?: ExtensionState
@@ -45,6 +48,7 @@ export interface ExtensionMessage {
mcpServers?: McpServer[]
commits?: GitCommit[]
listApiConfig?: ApiConfigMeta[]
mode?: Mode | 'enhance'
}
export interface ApiConfigMeta {
@@ -62,6 +66,7 @@ export interface ExtensionState {
currentApiConfigName?: string
listApiConfigMeta?: ApiConfigMeta[]
customInstructions?: string
customPrompts?: CustomPrompts
alwaysAllowReadOnly?: boolean
alwaysAllowWrite?: boolean
alwaysAllowExecute?: boolean
@@ -82,7 +87,8 @@ export interface ExtensionState {
terminalOutputLineLimit?: number
mcpEnabled: boolean
mode: Mode
modeApiConfigs?: Record<Mode, string>;
modeApiConfigs?: Record<Mode, string>
enhancementApiConfigId?: string
}
export interface ClineMessage {

View File

@@ -1,4 +1,7 @@
import { ApiConfiguration, ApiProvider } from "./api"
import { Mode } from "./modes"
export type PromptMode = Mode | 'enhance'
export type AudioType = "notification" | "celebration" | "progress_loop"
@@ -62,6 +65,10 @@ export interface WebviewMessage {
| "requestDelaySeconds"
| "setApiConfigPassword"
| "mode"
| "updatePrompt"
| "getSystemPrompt"
| "systemPrompt"
| "enhancementApiConfigId"
text?: string
disabled?: boolean
askResponse?: ClineAskResponse
@@ -74,6 +81,9 @@ export interface WebviewMessage {
serverName?: string
toolName?: string
alwaysAllow?: boolean
mode?: Mode
promptMode?: PromptMode
customPrompt?: string
dataUrls?: string[]
values?: Record<string, any>
query?: string

View File

@@ -2,4 +2,18 @@ export const codeMode = 'code' as const;
export const architectMode = 'architect' as const;
export const askMode = 'ask' as const;
export type Mode = typeof codeMode | typeof architectMode | typeof askMode;
export type Mode = typeof codeMode | typeof architectMode | typeof askMode;
export type CustomPrompts = {
ask?: string;
code?: string;
architect?: string;
enhance?: string;
}
export const defaultPrompts = {
[askMode]: "You are Cline, a knowledgeable technical assistant focused on answering questions and providing information about software development, technology, and related topics. You can analyze code, explain concepts, and access external resources while maintaining a read-only approach to the codebase. Make sure to answer the user's questions and don't rush to switch to implementing code.",
[codeMode]: "You are Cline, a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.",
[architectMode]: "You are Cline, a software architecture expert specializing in analyzing codebases, identifying patterns, and providing high-level technical guidance. You excel at understanding complex systems, evaluating architectural decisions, and suggesting improvements while maintaining a read-only approach to the codebase. Make sure to help the user come up with a solid implementation plan for their project and don't rush to switch to implementing code.",
enhance: "Generate an enhanced version of this prompt (reply with only the enhanced prompt - no conversation, explanations, lead-in, bullet points, placeholders, or surrounding quotes):"
} as const;

View File

@@ -1,80 +1,126 @@
import { enhancePrompt } from '../enhance-prompt'
import { buildApiHandler } from '../../api'
import { ApiConfiguration } from '../../shared/api'
import { OpenRouterHandler } from '../../api/providers/openrouter'
import { buildApiHandler, SingleCompletionHandler } from '../../api'
import { defaultPrompts } from '../../shared/modes'
// Mock the buildApiHandler function
// Mock the API handler
jest.mock('../../api', () => ({
buildApiHandler: jest.fn()
buildApiHandler: jest.fn()
}))
describe('enhancePrompt', () => {
const mockApiConfig: ApiConfiguration = {
apiProvider: 'openrouter',
apiKey: 'test-key',
openRouterApiKey: 'test-key',
openRouterModelId: 'test-model'
}
const mockApiConfig: ApiConfiguration = {
apiProvider: 'openai',
openAiApiKey: 'test-key',
openAiBaseUrl: 'https://api.openai.com/v1'
}
// Create a mock handler that looks like OpenRouterHandler
const mockHandler = {
completePrompt: jest.fn(),
createMessage: jest.fn(),
getModel: jest.fn()
}
// Make instanceof check work
Object.setPrototypeOf(mockHandler, OpenRouterHandler.prototype)
beforeEach(() => {
jest.clearAllMocks()
;(buildApiHandler as jest.Mock).mockReturnValue(mockHandler)
})
it('should throw error for non-OpenRouter providers', async () => {
const nonOpenRouterConfig: ApiConfiguration = {
apiProvider: 'anthropic',
apiKey: 'test-key',
apiModelId: 'claude-3'
beforeEach(() => {
jest.clearAllMocks()
// Mock the API handler with a completePrompt method
;(buildApiHandler as jest.Mock).mockReturnValue({
completePrompt: jest.fn().mockResolvedValue('Enhanced prompt'),
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: 'test-model',
info: {
maxTokens: 4096,
contextWindow: 8192,
supportsPromptCache: false
}
await expect(enhancePrompt(nonOpenRouterConfig, 'test')).rejects.toThrow('Prompt enhancement is only available with OpenRouter')
})
} as unknown as SingleCompletionHandler)
})
it('enhances prompt using default enhancement prompt when no custom prompt provided', async () => {
const result = await enhancePrompt(mockApiConfig, 'Test prompt')
expect(result).toBe('Enhanced prompt')
const handler = buildApiHandler(mockApiConfig)
expect((handler as any).completePrompt).toHaveBeenCalledWith(
`${defaultPrompts.enhance}\n\nTest prompt`
)
})
it('enhances prompt using custom enhancement prompt when provided', async () => {
const customEnhancePrompt = 'You are a custom prompt enhancer'
const result = await enhancePrompt(mockApiConfig, 'Test prompt', customEnhancePrompt)
expect(result).toBe('Enhanced prompt')
const handler = buildApiHandler(mockApiConfig)
expect((handler as any).completePrompt).toHaveBeenCalledWith(
`${customEnhancePrompt}\n\nTest prompt`
)
})
it('throws error for empty prompt input', async () => {
await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided')
})
it('throws error for missing API configuration', async () => {
await expect(enhancePrompt({} as ApiConfiguration, 'Test prompt')).rejects.toThrow('No valid API configuration provided')
})
it('throws error for API provider that does not support prompt enhancement', async () => {
(buildApiHandler as jest.Mock).mockReturnValue({
// No completePrompt method
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: 'test-model',
info: {
maxTokens: 4096,
contextWindow: 8192,
supportsPromptCache: false
}
})
})
it('should enhance a valid prompt', async () => {
const inputPrompt = 'Write a function to sort an array'
const enhancedPrompt = 'Write a TypeScript function that implements an efficient sorting algorithm for a generic array, including error handling and type safety'
mockHandler.completePrompt.mockResolvedValue(enhancedPrompt)
await expect(enhancePrompt(mockApiConfig, 'Test prompt')).rejects.toThrow('The selected API provider does not support prompt enhancement')
})
const result = await enhancePrompt(mockApiConfig, inputPrompt)
it('uses appropriate model based on provider', async () => {
const openRouterConfig: ApiConfiguration = {
apiProvider: 'openrouter',
openRouterApiKey: 'test-key',
openRouterModelId: 'test-model'
}
expect(result).toBe(enhancedPrompt)
expect(buildApiHandler).toHaveBeenCalledWith(mockApiConfig)
expect(mockHandler.completePrompt).toHaveBeenCalledWith(
expect.stringContaining(inputPrompt)
)
})
// Mock successful enhancement
;(buildApiHandler as jest.Mock).mockReturnValue({
completePrompt: jest.fn().mockResolvedValue('Enhanced prompt'),
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: 'test-model',
info: {
maxTokens: 4096,
contextWindow: 8192,
supportsPromptCache: false
}
})
} as unknown as SingleCompletionHandler)
it('should throw error when no prompt text is provided', async () => {
await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided')
expect(mockHandler.completePrompt).not.toHaveBeenCalled()
})
const result = await enhancePrompt(openRouterConfig, 'Test prompt')
expect(buildApiHandler).toHaveBeenCalledWith(openRouterConfig)
expect(result).toBe('Enhanced prompt')
})
it('should pass through API errors', async () => {
const inputPrompt = 'Test prompt'
mockHandler.completePrompt.mockRejectedValue('API error')
it('propagates API errors', async () => {
(buildApiHandler as jest.Mock).mockReturnValue({
completePrompt: jest.fn().mockRejectedValue(new Error('API Error')),
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: 'test-model',
info: {
maxTokens: 4096,
contextWindow: 8192,
supportsPromptCache: false
}
})
} as unknown as SingleCompletionHandler)
await expect(enhancePrompt(mockApiConfig, inputPrompt)).rejects.toBe('API error')
})
it('should pass the correct prompt format to the API', async () => {
const inputPrompt = 'Test prompt'
mockHandler.completePrompt.mockResolvedValue('Enhanced test prompt')
await enhancePrompt(mockApiConfig, inputPrompt)
expect(mockHandler.completePrompt).toHaveBeenCalledWith(
'Generate an enhanced version of this prompt (reply with only the enhanced prompt - no conversation, explanations, lead-in, bullet points, placeholders, or surrounding quotes):\n\nTest prompt'
)
})
await expect(enhancePrompt(mockApiConfig, 'Test prompt')).rejects.toThrow('API Error')
})
})

View File

@@ -1,26 +1,27 @@
import { ApiConfiguration } from "../shared/api"
import { buildApiHandler } from "../api"
import { OpenRouterHandler } from "../api/providers/openrouter"
import { buildApiHandler, SingleCompletionHandler } from "../api"
import { defaultPrompts } from "../shared/modes"
/**
* Enhances a prompt using the OpenRouter API without creating a full Cline instance or task history.
* Enhances a prompt using the configured API without creating a full Cline instance or task history.
* This is a lightweight alternative that only uses the API's completion functionality.
*/
export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string): Promise<string> {
export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string, enhancePrompt?: string): Promise<string> {
if (!promptText) {
throw new Error("No prompt text provided")
}
if (apiConfiguration.apiProvider !== "openrouter") {
throw new Error("Prompt enhancement is only available with OpenRouter")
if (!apiConfiguration || !apiConfiguration.apiProvider) {
throw new Error("No valid API configuration provided")
}
const handler = buildApiHandler(apiConfiguration)
// Type guard to check if handler is OpenRouterHandler
if (!(handler instanceof OpenRouterHandler)) {
throw new Error("Expected OpenRouter handler")
// Check if handler supports single completions
if (!('completePrompt' in handler)) {
throw new Error("The selected API provider does not support prompt enhancement")
}
const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt - no conversation, explanations, lead-in, bullet points, placeholders, or surrounding quotes):\n\n${promptText}`
return handler.completePrompt(prompt)
const enhancePromptText = enhancePrompt ?? defaultPrompts.enhance
const prompt = `${enhancePromptText}\n\n${promptText}`
return (handler as SingleCompletionHandler).completePrompt(prompt)
}