mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-21 12:51:17 -05:00
Enhance prompt button for openrouter
This commit is contained in:
@@ -11,6 +11,10 @@ import { GeminiHandler } from "./providers/gemini"
|
||||
import { OpenAiNativeHandler } from "./providers/openai-native"
|
||||
import { ApiStream } from "./transform/stream"
|
||||
|
||||
export interface SingleCompletionHandler {
|
||||
completePrompt(prompt: string): Promise<string>
|
||||
}
|
||||
|
||||
export interface ApiHandler {
|
||||
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
|
||||
getModel(): { id: string; info: ModelInfo }
|
||||
|
||||
@@ -51,6 +51,14 @@ describe('OpenRouterHandler', () => {
|
||||
})
|
||||
})
|
||||
|
||||
test('getModel returns default model info when options are not provided', () => {
|
||||
const handler = new OpenRouterHandler({})
|
||||
const result = handler.getModel()
|
||||
|
||||
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
|
||||
expect(result.info.supportsPromptCache).toBe(true)
|
||||
})
|
||||
|
||||
test('createMessage generates correct stream chunks', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
@@ -118,4 +126,158 @@ describe('OpenRouterHandler', () => {
|
||||
stream: true
|
||||
}))
|
||||
})
|
||||
|
||||
test('createMessage with middle-out transform enabled', async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterUseMiddleOutTransform: true
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
|
||||
await handler.createMessage('test', []).next()
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
transforms: ['middle-out']
|
||||
}))
|
||||
})
|
||||
|
||||
test('createMessage with Claude model adds cache control', async () => {
|
||||
const handler = new OpenRouterHandler({
|
||||
...mockOptions,
|
||||
openRouterModelId: 'anthropic/claude-3.5-sonnet'
|
||||
})
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
id: 'test-id',
|
||||
choices: [{
|
||||
delta: {
|
||||
content: 'test response'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: 'user', content: 'message 1' },
|
||||
{ role: 'assistant', content: 'response 1' },
|
||||
{ role: 'user', content: 'message 2' }
|
||||
]
|
||||
|
||||
await handler.createMessage('test system', messages).next()
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'system',
|
||||
content: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
cache_control: { type: 'ephemeral' }
|
||||
})
|
||||
])
|
||||
})
|
||||
])
|
||||
}))
|
||||
})
|
||||
|
||||
test('createMessage handles API errors', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
error: {
|
||||
message: 'API Error',
|
||||
code: 500
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
|
||||
const generator = handler.createMessage('test', [])
|
||||
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||
})
|
||||
|
||||
test('completePrompt returns correct response', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockResponse = {
|
||||
choices: [{
|
||||
message: {
|
||||
content: 'test completion'
|
||||
}
|
||||
}]
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
|
||||
const result = await handler.completePrompt('test prompt')
|
||||
|
||||
expect(result).toBe('test completion')
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: mockOptions.openRouterModelId,
|
||||
messages: [{ role: 'user', content: 'test prompt' }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
})
|
||||
})
|
||||
|
||||
test('completePrompt handles API errors', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockError = {
|
||||
error: {
|
||||
message: 'API Error',
|
||||
code: 500
|
||||
}
|
||||
}
|
||||
|
||||
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
|
||||
await expect(handler.completePrompt('test prompt'))
|
||||
.rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||
})
|
||||
|
||||
test('completePrompt handles unexpected errors', async () => {
|
||||
const handler = new OpenRouterHandler(mockOptions)
|
||||
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||
completions: { create: mockCreate }
|
||||
} as any
|
||||
|
||||
await expect(handler.completePrompt('test prompt'))
|
||||
.rejects.toThrow('OpenRouter completion error: Unexpected error')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,11 +4,11 @@ import OpenAI from "openai"
|
||||
import { ApiHandler } from "../"
|
||||
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
|
||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||
import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
|
||||
import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
|
||||
import delay from "delay"
|
||||
|
||||
// Add custom interface for OpenRouter params
|
||||
interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming {
|
||||
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
|
||||
transforms?: string[];
|
||||
}
|
||||
|
||||
@@ -17,7 +17,12 @@ interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
|
||||
fullResponseText: string;
|
||||
}
|
||||
|
||||
export class OpenRouterHandler implements ApiHandler {
|
||||
// Interface for providers that support single completions
|
||||
export interface SingleCompletionHandler {
|
||||
completePrompt(prompt: string): Promise<string>
|
||||
}
|
||||
|
||||
export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||
private options: ApiHandlerOptions
|
||||
private client: OpenAI
|
||||
|
||||
@@ -184,4 +189,28 @@ export class OpenRouterHandler implements ApiHandler {
|
||||
}
|
||||
return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
|
||||
}
|
||||
|
||||
async completePrompt(prompt: string): Promise<string> {
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.getModel().id,
|
||||
messages: [{ role: "user", content: prompt }],
|
||||
temperature: 0,
|
||||
stream: false
|
||||
})
|
||||
|
||||
if ("error" in response) {
|
||||
const error = response.error as { message?: string; code?: number }
|
||||
throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
|
||||
}
|
||||
|
||||
const completion = response as OpenAI.Chat.ChatCompletion
|
||||
return completion.choices[0]?.message?.content || ""
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`OpenRouter completion error: ${error.message}`)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for"
|
||||
import * as path from "path"
|
||||
import { serializeError } from "serialize-error"
|
||||
import * as vscode from "vscode"
|
||||
import { ApiHandler, buildApiHandler } from "../api"
|
||||
import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
|
||||
import { ApiStream } from "../api/transform/stream"
|
||||
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
|
||||
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
|
||||
@@ -49,6 +49,7 @@ import { truncateHalfConversation } from "./sliding-window"
|
||||
import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider"
|
||||
import { detectCodeOmission } from "../integrations/editor/detect-omission"
|
||||
import { BrowserSession } from "../services/browser/BrowserSession"
|
||||
import { OpenRouterHandler } from "../api/providers/openrouter"
|
||||
|
||||
const cwd =
|
||||
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
|
||||
@@ -126,6 +127,22 @@ export class Cline {
|
||||
}
|
||||
}
|
||||
|
||||
async enhancePrompt(promptText: string): Promise<string> {
|
||||
if (!promptText) {
|
||||
throw new Error("No prompt text provided")
|
||||
}
|
||||
|
||||
const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no bullet points): ${promptText}`
|
||||
|
||||
// Check if the API handler supports completePrompt
|
||||
if (this.api instanceof OpenRouterHandler) {
|
||||
return this.api.completePrompt(prompt)
|
||||
}
|
||||
|
||||
// Otherwise just return the prompt as is
|
||||
return prompt;
|
||||
}
|
||||
|
||||
// Storing task to disk for history
|
||||
|
||||
private async ensureTaskDirectoryExists(): Promise<string> {
|
||||
|
||||
@@ -23,6 +23,7 @@ import { openMention } from "../mentions"
|
||||
import { getNonce } from "./getNonce"
|
||||
import { getUri } from "./getUri"
|
||||
import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
|
||||
import { enhancePrompt } from "../../utils/enhance-prompt"
|
||||
|
||||
/*
|
||||
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
|
||||
@@ -637,6 +638,26 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
||||
await this.updateGlobalState("writeDelayMs", message.value)
|
||||
await this.postStateToWebview()
|
||||
break
|
||||
case "enhancePrompt":
|
||||
if (message.text) {
|
||||
try {
|
||||
const { apiConfiguration } = await this.getState()
|
||||
const enhanceConfig = {
|
||||
...apiConfiguration,
|
||||
apiProvider: "openrouter" as const,
|
||||
openRouterModelId: "gpt-4o",
|
||||
}
|
||||
const enhancedPrompt = await enhancePrompt(enhanceConfig, message.text)
|
||||
await this.postMessageToWebview({
|
||||
type: "enhancedPrompt",
|
||||
text: enhancedPrompt
|
||||
})
|
||||
} catch (error) {
|
||||
console.error("Error enhancing prompt:", error)
|
||||
vscode.window.showErrorMessage("Failed to enhance prompt")
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
},
|
||||
null,
|
||||
|
||||
@@ -19,6 +19,7 @@ export interface ExtensionMessage {
|
||||
| "openRouterModels"
|
||||
| "openAiModels"
|
||||
| "mcpServers"
|
||||
| "enhancedPrompt"
|
||||
text?: string
|
||||
action?:
|
||||
| "chatButtonClicked"
|
||||
|
||||
@@ -43,6 +43,9 @@ export interface WebviewMessage {
|
||||
| "fuzzyMatchThreshold"
|
||||
| "preferredLanguage"
|
||||
| "writeDelayMs"
|
||||
| "enhancePrompt"
|
||||
| "enhancedPrompt"
|
||||
| "draggedImages"
|
||||
text?: string
|
||||
disabled?: boolean
|
||||
askResponse?: ClineAskResponse
|
||||
@@ -52,10 +55,10 @@ export interface WebviewMessage {
|
||||
value?: number
|
||||
commands?: string[]
|
||||
audioType?: AudioType
|
||||
// For toggleToolAutoApprove
|
||||
serverName?: string
|
||||
toolName?: string
|
||||
alwaysAllow?: boolean
|
||||
dataUrls?: string[]
|
||||
}
|
||||
|
||||
export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"
|
||||
|
||||
80
src/utils/__tests__/enhance-prompt.test.ts
Normal file
80
src/utils/__tests__/enhance-prompt.test.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import { enhancePrompt } from '../enhance-prompt'
|
||||
import { buildApiHandler } from '../../api'
|
||||
import { ApiConfiguration } from '../../shared/api'
|
||||
import { OpenRouterHandler } from '../../api/providers/openrouter'
|
||||
|
||||
// Mock the buildApiHandler function
|
||||
jest.mock('../../api', () => ({
|
||||
buildApiHandler: jest.fn()
|
||||
}))
|
||||
|
||||
describe('enhancePrompt', () => {
|
||||
const mockApiConfig: ApiConfiguration = {
|
||||
apiProvider: 'openrouter',
|
||||
apiKey: 'test-key',
|
||||
openRouterApiKey: 'test-key',
|
||||
openRouterModelId: 'test-model'
|
||||
}
|
||||
|
||||
// Create a mock handler that looks like OpenRouterHandler
|
||||
const mockHandler = {
|
||||
completePrompt: jest.fn(),
|
||||
createMessage: jest.fn(),
|
||||
getModel: jest.fn()
|
||||
}
|
||||
|
||||
// Make instanceof check work
|
||||
Object.setPrototypeOf(mockHandler, OpenRouterHandler.prototype)
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
;(buildApiHandler as jest.Mock).mockReturnValue(mockHandler)
|
||||
})
|
||||
|
||||
it('should throw error for non-OpenRouter providers', async () => {
|
||||
const nonOpenRouterConfig: ApiConfiguration = {
|
||||
apiProvider: 'anthropic',
|
||||
apiKey: 'test-key',
|
||||
apiModelId: 'claude-3'
|
||||
}
|
||||
await expect(enhancePrompt(nonOpenRouterConfig, 'test')).rejects.toThrow('Prompt enhancement is only available with OpenRouter')
|
||||
})
|
||||
|
||||
it('should enhance a valid prompt', async () => {
|
||||
const inputPrompt = 'Write a function to sort an array'
|
||||
const enhancedPrompt = 'Write a TypeScript function that implements an efficient sorting algorithm for a generic array, including error handling and type safety'
|
||||
|
||||
mockHandler.completePrompt.mockResolvedValue(enhancedPrompt)
|
||||
|
||||
const result = await enhancePrompt(mockApiConfig, inputPrompt)
|
||||
|
||||
expect(result).toBe(enhancedPrompt)
|
||||
expect(buildApiHandler).toHaveBeenCalledWith(mockApiConfig)
|
||||
expect(mockHandler.completePrompt).toHaveBeenCalledWith(
|
||||
expect.stringContaining(inputPrompt)
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error when no prompt text is provided', async () => {
|
||||
await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided')
|
||||
expect(mockHandler.completePrompt).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass through API errors', async () => {
|
||||
const inputPrompt = 'Test prompt'
|
||||
mockHandler.completePrompt.mockRejectedValue('API error')
|
||||
|
||||
await expect(enhancePrompt(mockApiConfig, inputPrompt)).rejects.toBe('API error')
|
||||
})
|
||||
|
||||
it('should pass the correct prompt format to the API', async () => {
|
||||
const inputPrompt = 'Test prompt'
|
||||
mockHandler.completePrompt.mockResolvedValue('Enhanced test prompt')
|
||||
|
||||
await enhancePrompt(mockApiConfig, inputPrompt)
|
||||
|
||||
expect(mockHandler.completePrompt).toHaveBeenCalledWith(
|
||||
'Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): Test prompt'
|
||||
)
|
||||
})
|
||||
})
|
||||
26
src/utils/enhance-prompt.ts
Normal file
26
src/utils/enhance-prompt.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import { ApiConfiguration } from "../shared/api"
|
||||
import { buildApiHandler } from "../api"
|
||||
import { OpenRouterHandler, SingleCompletionHandler } from "../api/providers/openrouter"
|
||||
|
||||
/**
|
||||
* Enhances a prompt using the OpenRouter API without creating a full Cline instance or task history.
|
||||
* This is a lightweight alternative that only uses the API's completion functionality.
|
||||
*/
|
||||
export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string): Promise<string> {
|
||||
if (!promptText) {
|
||||
throw new Error("No prompt text provided")
|
||||
}
|
||||
if (apiConfiguration.apiProvider !== "openrouter") {
|
||||
throw new Error("Prompt enhancement is only available with OpenRouter")
|
||||
}
|
||||
|
||||
const handler = buildApiHandler(apiConfiguration)
|
||||
|
||||
// Type guard to check if handler is OpenRouterHandler
|
||||
if (!(handler instanceof OpenRouterHandler)) {
|
||||
throw new Error("Expected OpenRouter handler")
|
||||
}
|
||||
|
||||
const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): ${promptText}`
|
||||
return handler.completePrompt(prompt)
|
||||
}
|
||||
Reference in New Issue
Block a user