mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 12:21:13 -05:00
Enhance prompt button for openrouter
This commit is contained in:
@@ -11,6 +11,10 @@ import { GeminiHandler } from "./providers/gemini"
|
|||||||
import { OpenAiNativeHandler } from "./providers/openai-native"
|
import { OpenAiNativeHandler } from "./providers/openai-native"
|
||||||
import { ApiStream } from "./transform/stream"
|
import { ApiStream } from "./transform/stream"
|
||||||
|
|
||||||
|
export interface SingleCompletionHandler {
|
||||||
|
completePrompt(prompt: string): Promise<string>
|
||||||
|
}
|
||||||
|
|
||||||
export interface ApiHandler {
|
export interface ApiHandler {
|
||||||
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
|
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
|
||||||
getModel(): { id: string; info: ModelInfo }
|
getModel(): { id: string; info: ModelInfo }
|
||||||
|
|||||||
@@ -51,6 +51,14 @@ describe('OpenRouterHandler', () => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test('getModel returns default model info when options are not provided', () => {
|
||||||
|
const handler = new OpenRouterHandler({})
|
||||||
|
const result = handler.getModel()
|
||||||
|
|
||||||
|
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
|
||||||
|
expect(result.info.supportsPromptCache).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
test('createMessage generates correct stream chunks', async () => {
|
test('createMessage generates correct stream chunks', async () => {
|
||||||
const handler = new OpenRouterHandler(mockOptions)
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
const mockStream = {
|
const mockStream = {
|
||||||
@@ -118,4 +126,158 @@ describe('OpenRouterHandler', () => {
|
|||||||
stream: true
|
stream: true
|
||||||
}))
|
}))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test('createMessage with middle-out transform enabled', async () => {
|
||||||
|
const handler = new OpenRouterHandler({
|
||||||
|
...mockOptions,
|
||||||
|
openRouterUseMiddleOutTransform: true
|
||||||
|
})
|
||||||
|
const mockStream = {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
id: 'test-id',
|
||||||
|
choices: [{
|
||||||
|
delta: {
|
||||||
|
content: 'test response'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate }
|
||||||
|
} as any
|
||||||
|
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||||
|
|
||||||
|
await handler.createMessage('test', []).next()
|
||||||
|
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
transforms: ['middle-out']
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
test('createMessage with Claude model adds cache control', async () => {
|
||||||
|
const handler = new OpenRouterHandler({
|
||||||
|
...mockOptions,
|
||||||
|
openRouterModelId: 'anthropic/claude-3.5-sonnet'
|
||||||
|
})
|
||||||
|
const mockStream = {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
id: 'test-id',
|
||||||
|
choices: [{
|
||||||
|
delta: {
|
||||||
|
content: 'test response'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate }
|
||||||
|
} as any
|
||||||
|
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
|
||||||
|
|
||||||
|
const messages: Anthropic.Messages.MessageParam[] = [
|
||||||
|
{ role: 'user', content: 'message 1' },
|
||||||
|
{ role: 'assistant', content: 'response 1' },
|
||||||
|
{ role: 'user', content: 'message 2' }
|
||||||
|
]
|
||||||
|
|
||||||
|
await handler.createMessage('test system', messages).next()
|
||||||
|
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
messages: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
role: 'system',
|
||||||
|
content: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
cache_control: { type: 'ephemeral' }
|
||||||
|
})
|
||||||
|
])
|
||||||
|
})
|
||||||
|
])
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
test('createMessage handles API errors', async () => {
|
||||||
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
|
const mockStream = {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
yield {
|
||||||
|
error: {
|
||||||
|
message: 'API Error',
|
||||||
|
code: 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue(mockStream)
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate }
|
||||||
|
} as any
|
||||||
|
|
||||||
|
const generator = handler.createMessage('test', [])
|
||||||
|
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||||
|
})
|
||||||
|
|
||||||
|
test('completePrompt returns correct response', async () => {
|
||||||
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
|
const mockResponse = {
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: 'test completion'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate }
|
||||||
|
} as any
|
||||||
|
|
||||||
|
const result = await handler.completePrompt('test prompt')
|
||||||
|
|
||||||
|
expect(result).toBe('test completion')
|
||||||
|
expect(mockCreate).toHaveBeenCalledWith({
|
||||||
|
model: mockOptions.openRouterModelId,
|
||||||
|
messages: [{ role: 'user', content: 'test prompt' }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('completePrompt handles API errors', async () => {
|
||||||
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
|
const mockError = {
|
||||||
|
error: {
|
||||||
|
message: 'API Error',
|
||||||
|
code: 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockCreate = jest.fn().mockResolvedValue(mockError)
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate }
|
||||||
|
} as any
|
||||||
|
|
||||||
|
await expect(handler.completePrompt('test prompt'))
|
||||||
|
.rejects.toThrow('OpenRouter API Error 500: API Error')
|
||||||
|
})
|
||||||
|
|
||||||
|
test('completePrompt handles unexpected errors', async () => {
|
||||||
|
const handler = new OpenRouterHandler(mockOptions)
|
||||||
|
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
|
||||||
|
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
|
||||||
|
completions: { create: mockCreate }
|
||||||
|
} as any
|
||||||
|
|
||||||
|
await expect(handler.completePrompt('test prompt'))
|
||||||
|
.rejects.toThrow('OpenRouter completion error: Unexpected error')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import OpenAI from "openai"
|
|||||||
import { ApiHandler } from "../"
|
import { ApiHandler } from "../"
|
||||||
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
|
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
|
||||||
import { convertToOpenAiMessages } from "../transform/openai-format"
|
import { convertToOpenAiMessages } from "../transform/openai-format"
|
||||||
import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
|
import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
|
||||||
import delay from "delay"
|
import delay from "delay"
|
||||||
|
|
||||||
// Add custom interface for OpenRouter params
|
// Add custom interface for OpenRouter params
|
||||||
interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming {
|
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
|
||||||
transforms?: string[];
|
transforms?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -17,7 +17,12 @@ interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
|
|||||||
fullResponseText: string;
|
fullResponseText: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class OpenRouterHandler implements ApiHandler {
|
// Interface for providers that support single completions
|
||||||
|
export interface SingleCompletionHandler {
|
||||||
|
completePrompt(prompt: string): Promise<string>
|
||||||
|
}
|
||||||
|
|
||||||
|
export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
|
||||||
private options: ApiHandlerOptions
|
private options: ApiHandlerOptions
|
||||||
private client: OpenAI
|
private client: OpenAI
|
||||||
|
|
||||||
@@ -184,4 +189,28 @@ export class OpenRouterHandler implements ApiHandler {
|
|||||||
}
|
}
|
||||||
return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
|
return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async completePrompt(prompt: string): Promise<string> {
|
||||||
|
try {
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.getModel().id,
|
||||||
|
messages: [{ role: "user", content: prompt }],
|
||||||
|
temperature: 0,
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
|
||||||
|
if ("error" in response) {
|
||||||
|
const error = response.error as { message?: string; code?: number }
|
||||||
|
throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const completion = response as OpenAI.Chat.ChatCompletion
|
||||||
|
return completion.choices[0]?.message?.content || ""
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new Error(`OpenRouter completion error: ${error.message}`)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for"
|
|||||||
import * as path from "path"
|
import * as path from "path"
|
||||||
import { serializeError } from "serialize-error"
|
import { serializeError } from "serialize-error"
|
||||||
import * as vscode from "vscode"
|
import * as vscode from "vscode"
|
||||||
import { ApiHandler, buildApiHandler } from "../api"
|
import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
|
||||||
import { ApiStream } from "../api/transform/stream"
|
import { ApiStream } from "../api/transform/stream"
|
||||||
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
|
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
|
||||||
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
|
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
|
||||||
@@ -49,6 +49,7 @@ import { truncateHalfConversation } from "./sliding-window"
|
|||||||
import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider"
|
import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider"
|
||||||
import { detectCodeOmission } from "../integrations/editor/detect-omission"
|
import { detectCodeOmission } from "../integrations/editor/detect-omission"
|
||||||
import { BrowserSession } from "../services/browser/BrowserSession"
|
import { BrowserSession } from "../services/browser/BrowserSession"
|
||||||
|
import { OpenRouterHandler } from "../api/providers/openrouter"
|
||||||
|
|
||||||
const cwd =
|
const cwd =
|
||||||
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
|
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
|
||||||
@@ -126,6 +127,22 @@ export class Cline {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async enhancePrompt(promptText: string): Promise<string> {
|
||||||
|
if (!promptText) {
|
||||||
|
throw new Error("No prompt text provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no bullet points): ${promptText}`
|
||||||
|
|
||||||
|
// Check if the API handler supports completePrompt
|
||||||
|
if (this.api instanceof OpenRouterHandler) {
|
||||||
|
return this.api.completePrompt(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise just return the prompt as is
|
||||||
|
return prompt;
|
||||||
|
}
|
||||||
|
|
||||||
// Storing task to disk for history
|
// Storing task to disk for history
|
||||||
|
|
||||||
private async ensureTaskDirectoryExists(): Promise<string> {
|
private async ensureTaskDirectoryExists(): Promise<string> {
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import { openMention } from "../mentions"
|
|||||||
import { getNonce } from "./getNonce"
|
import { getNonce } from "./getNonce"
|
||||||
import { getUri } from "./getUri"
|
import { getUri } from "./getUri"
|
||||||
import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
|
import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
|
||||||
|
import { enhancePrompt } from "../../utils/enhance-prompt"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
|
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
|
||||||
@@ -637,6 +638,26 @@ export class ClineProvider implements vscode.WebviewViewProvider {
|
|||||||
await this.updateGlobalState("writeDelayMs", message.value)
|
await this.updateGlobalState("writeDelayMs", message.value)
|
||||||
await this.postStateToWebview()
|
await this.postStateToWebview()
|
||||||
break
|
break
|
||||||
|
case "enhancePrompt":
|
||||||
|
if (message.text) {
|
||||||
|
try {
|
||||||
|
const { apiConfiguration } = await this.getState()
|
||||||
|
const enhanceConfig = {
|
||||||
|
...apiConfiguration,
|
||||||
|
apiProvider: "openrouter" as const,
|
||||||
|
openRouterModelId: "gpt-4o",
|
||||||
|
}
|
||||||
|
const enhancedPrompt = await enhancePrompt(enhanceConfig, message.text)
|
||||||
|
await this.postMessageToWebview({
|
||||||
|
type: "enhancedPrompt",
|
||||||
|
text: enhancedPrompt
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error enhancing prompt:", error)
|
||||||
|
vscode.window.showErrorMessage("Failed to enhance prompt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
null,
|
null,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ export interface ExtensionMessage {
|
|||||||
| "openRouterModels"
|
| "openRouterModels"
|
||||||
| "openAiModels"
|
| "openAiModels"
|
||||||
| "mcpServers"
|
| "mcpServers"
|
||||||
|
| "enhancedPrompt"
|
||||||
text?: string
|
text?: string
|
||||||
action?:
|
action?:
|
||||||
| "chatButtonClicked"
|
| "chatButtonClicked"
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ export interface WebviewMessage {
|
|||||||
| "fuzzyMatchThreshold"
|
| "fuzzyMatchThreshold"
|
||||||
| "preferredLanguage"
|
| "preferredLanguage"
|
||||||
| "writeDelayMs"
|
| "writeDelayMs"
|
||||||
|
| "enhancePrompt"
|
||||||
|
| "enhancedPrompt"
|
||||||
|
| "draggedImages"
|
||||||
text?: string
|
text?: string
|
||||||
disabled?: boolean
|
disabled?: boolean
|
||||||
askResponse?: ClineAskResponse
|
askResponse?: ClineAskResponse
|
||||||
@@ -52,10 +55,10 @@ export interface WebviewMessage {
|
|||||||
value?: number
|
value?: number
|
||||||
commands?: string[]
|
commands?: string[]
|
||||||
audioType?: AudioType
|
audioType?: AudioType
|
||||||
// For toggleToolAutoApprove
|
|
||||||
serverName?: string
|
serverName?: string
|
||||||
toolName?: string
|
toolName?: string
|
||||||
alwaysAllow?: boolean
|
alwaysAllow?: boolean
|
||||||
|
dataUrls?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"
|
export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"
|
||||||
|
|||||||
80
src/utils/__tests__/enhance-prompt.test.ts
Normal file
80
src/utils/__tests__/enhance-prompt.test.ts
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import { enhancePrompt } from '../enhance-prompt'
|
||||||
|
import { buildApiHandler } from '../../api'
|
||||||
|
import { ApiConfiguration } from '../../shared/api'
|
||||||
|
import { OpenRouterHandler } from '../../api/providers/openrouter'
|
||||||
|
|
||||||
|
// Mock the buildApiHandler function
|
||||||
|
jest.mock('../../api', () => ({
|
||||||
|
buildApiHandler: jest.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('enhancePrompt', () => {
|
||||||
|
const mockApiConfig: ApiConfiguration = {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
openRouterApiKey: 'test-key',
|
||||||
|
openRouterModelId: 'test-model'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a mock handler that looks like OpenRouterHandler
|
||||||
|
const mockHandler = {
|
||||||
|
completePrompt: jest.fn(),
|
||||||
|
createMessage: jest.fn(),
|
||||||
|
getModel: jest.fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make instanceof check work
|
||||||
|
Object.setPrototypeOf(mockHandler, OpenRouterHandler.prototype)
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks()
|
||||||
|
;(buildApiHandler as jest.Mock).mockReturnValue(mockHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should throw error for non-OpenRouter providers', async () => {
|
||||||
|
const nonOpenRouterConfig: ApiConfiguration = {
|
||||||
|
apiProvider: 'anthropic',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiModelId: 'claude-3'
|
||||||
|
}
|
||||||
|
await expect(enhancePrompt(nonOpenRouterConfig, 'test')).rejects.toThrow('Prompt enhancement is only available with OpenRouter')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should enhance a valid prompt', async () => {
|
||||||
|
const inputPrompt = 'Write a function to sort an array'
|
||||||
|
const enhancedPrompt = 'Write a TypeScript function that implements an efficient sorting algorithm for a generic array, including error handling and type safety'
|
||||||
|
|
||||||
|
mockHandler.completePrompt.mockResolvedValue(enhancedPrompt)
|
||||||
|
|
||||||
|
const result = await enhancePrompt(mockApiConfig, inputPrompt)
|
||||||
|
|
||||||
|
expect(result).toBe(enhancedPrompt)
|
||||||
|
expect(buildApiHandler).toHaveBeenCalledWith(mockApiConfig)
|
||||||
|
expect(mockHandler.completePrompt).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(inputPrompt)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should throw error when no prompt text is provided', async () => {
|
||||||
|
await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided')
|
||||||
|
expect(mockHandler.completePrompt).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should pass through API errors', async () => {
|
||||||
|
const inputPrompt = 'Test prompt'
|
||||||
|
mockHandler.completePrompt.mockRejectedValue('API error')
|
||||||
|
|
||||||
|
await expect(enhancePrompt(mockApiConfig, inputPrompt)).rejects.toBe('API error')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should pass the correct prompt format to the API', async () => {
|
||||||
|
const inputPrompt = 'Test prompt'
|
||||||
|
mockHandler.completePrompt.mockResolvedValue('Enhanced test prompt')
|
||||||
|
|
||||||
|
await enhancePrompt(mockApiConfig, inputPrompt)
|
||||||
|
|
||||||
|
expect(mockHandler.completePrompt).toHaveBeenCalledWith(
|
||||||
|
'Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): Test prompt'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
26
src/utils/enhance-prompt.ts
Normal file
26
src/utils/enhance-prompt.ts
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import { ApiConfiguration } from "../shared/api"
|
||||||
|
import { buildApiHandler } from "../api"
|
||||||
|
import { OpenRouterHandler, SingleCompletionHandler } from "../api/providers/openrouter"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Enhances a prompt using the OpenRouter API without creating a full Cline instance or task history.
|
||||||
|
* This is a lightweight alternative that only uses the API's completion functionality.
|
||||||
|
*/
|
||||||
|
export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string): Promise<string> {
|
||||||
|
if (!promptText) {
|
||||||
|
throw new Error("No prompt text provided")
|
||||||
|
}
|
||||||
|
if (apiConfiguration.apiProvider !== "openrouter") {
|
||||||
|
throw new Error("Prompt enhancement is only available with OpenRouter")
|
||||||
|
}
|
||||||
|
|
||||||
|
const handler = buildApiHandler(apiConfiguration)
|
||||||
|
|
||||||
|
// Type guard to check if handler is OpenRouterHandler
|
||||||
|
if (!(handler instanceof OpenRouterHandler)) {
|
||||||
|
throw new Error("Expected OpenRouter handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): ${promptText}`
|
||||||
|
return handler.completePrompt(prompt)
|
||||||
|
}
|
||||||
@@ -13,7 +13,7 @@ import { MAX_IMAGES_PER_MESSAGE } from "./ChatView"
|
|||||||
import ContextMenu from "./ContextMenu"
|
import ContextMenu from "./ContextMenu"
|
||||||
import Thumbnails from "../common/Thumbnails"
|
import Thumbnails from "../common/Thumbnails"
|
||||||
|
|
||||||
declare const vscode: any;
|
import { vscode } from "../../utils/vscode"
|
||||||
|
|
||||||
interface ChatTextAreaProps {
|
interface ChatTextAreaProps {
|
||||||
inputValue: string
|
inputValue: string
|
||||||
@@ -44,8 +44,20 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
|
|||||||
},
|
},
|
||||||
ref,
|
ref,
|
||||||
) => {
|
) => {
|
||||||
const { filePaths } = useExtensionState()
|
const { filePaths, apiConfiguration } = useExtensionState()
|
||||||
const [isTextAreaFocused, setIsTextAreaFocused] = useState(false)
|
const [isTextAreaFocused, setIsTextAreaFocused] = useState(false)
|
||||||
|
|
||||||
|
// Handle enhanced prompt response
|
||||||
|
useEffect(() => {
|
||||||
|
const messageHandler = (event: MessageEvent) => {
|
||||||
|
const message = event.data
|
||||||
|
if (message.type === 'enhancedPrompt' && message.text) {
|
||||||
|
setInputValue(message.text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
window.addEventListener('message', messageHandler)
|
||||||
|
return () => window.removeEventListener('message', messageHandler)
|
||||||
|
}, [setInputValue])
|
||||||
const [thumbnailsHeight, setThumbnailsHeight] = useState(0)
|
const [thumbnailsHeight, setThumbnailsHeight] = useState(0)
|
||||||
const [textAreaBaseHeight, setTextAreaBaseHeight] = useState<number | undefined>(undefined)
|
const [textAreaBaseHeight, setTextAreaBaseHeight] = useState<number | undefined>(undefined)
|
||||||
const [showContextMenu, setShowContextMenu] = useState(false)
|
const [showContextMenu, setShowContextMenu] = useState(false)
|
||||||
@@ -60,6 +72,63 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
|
|||||||
const [intendedCursorPosition, setIntendedCursorPosition] = useState<number | null>(null)
|
const [intendedCursorPosition, setIntendedCursorPosition] = useState<number | null>(null)
|
||||||
const contextMenuContainerRef = useRef<HTMLDivElement>(null)
|
const contextMenuContainerRef = useRef<HTMLDivElement>(null)
|
||||||
|
|
||||||
|
const [isEnhancingPrompt, setIsEnhancingPrompt] = useState(false)
|
||||||
|
|
||||||
|
const handleEnhancePrompt = useCallback(() => {
|
||||||
|
if (!textAreaDisabled) {
|
||||||
|
const trimmedInput = inputValue.trim()
|
||||||
|
if (trimmedInput) {
|
||||||
|
setIsEnhancingPrompt(true)
|
||||||
|
const message = {
|
||||||
|
type: "enhancePrompt" as const,
|
||||||
|
text: trimmedInput,
|
||||||
|
}
|
||||||
|
vscode.postMessage(message)
|
||||||
|
} else {
|
||||||
|
const promptDescription = "The 'Enhance Prompt' button helps improve your prompt by providing additional context, clarification, or rephrasing. Try typing a prompt in here and clicking the button again to see how it works."
|
||||||
|
setInputValue(promptDescription)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [inputValue, textAreaDisabled, setInputValue])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const messageHandler = (event: MessageEvent) => {
|
||||||
|
const message = event.data
|
||||||
|
if (message.type === 'enhancedPrompt') {
|
||||||
|
setInputValue(message.text)
|
||||||
|
setIsEnhancingPrompt(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
window.addEventListener('message', messageHandler)
|
||||||
|
return () => window.removeEventListener('message', messageHandler)
|
||||||
|
}, [setInputValue])
|
||||||
|
|
||||||
|
// Handle enhanced prompt response
|
||||||
|
useEffect(() => {
|
||||||
|
const messageHandler = (event: MessageEvent) => {
|
||||||
|
const message = event.data
|
||||||
|
if (message.type === 'enhancedPrompt') {
|
||||||
|
setInputValue(message.text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
window.addEventListener('message', messageHandler)
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener('message', messageHandler)
|
||||||
|
}
|
||||||
|
}, [setInputValue])
|
||||||
|
|
||||||
|
// Handle enhanced prompt response
|
||||||
|
useEffect(() => {
|
||||||
|
const messageHandler = (event: MessageEvent) => {
|
||||||
|
const message = event.data
|
||||||
|
if (message.type === 'enhancedPrompt' && message.text) {
|
||||||
|
setInputValue(message.text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
window.addEventListener('message', messageHandler)
|
||||||
|
return () => window.removeEventListener('message', messageHandler)
|
||||||
|
}, [setInputValue])
|
||||||
|
|
||||||
const queryItems = useMemo(() => {
|
const queryItems = useMemo(() => {
|
||||||
return [
|
return [
|
||||||
{ type: ContextMenuOptionType.Problems, value: "problems" },
|
{ type: ContextMenuOptionType.Problems, value: "problems" },
|
||||||
@@ -423,68 +492,64 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div style={{
|
||||||
style={{
|
padding: "10px 15px",
|
||||||
padding: "10px 15px",
|
opacity: textAreaDisabled ? 0.5 : 1,
|
||||||
opacity: textAreaDisabled ? 0.5 : 1,
|
position: "relative",
|
||||||
position: "relative",
|
display: "flex",
|
||||||
display: "flex",
|
}}
|
||||||
}}
|
onDrop={async (e) => {
|
||||||
onDrop={async (e) => {
|
e.preventDefault()
|
||||||
console.log("onDrop called")
|
const files = Array.from(e.dataTransfer.files)
|
||||||
e.preventDefault()
|
const text = e.dataTransfer.getData("text")
|
||||||
const files = Array.from(e.dataTransfer.files)
|
if (text) {
|
||||||
const text = e.dataTransfer.getData("text")
|
const newValue =
|
||||||
if (text) {
|
inputValue.slice(0, cursorPosition) + text + inputValue.slice(cursorPosition)
|
||||||
const newValue =
|
setInputValue(newValue)
|
||||||
inputValue.slice(0, cursorPosition) + text + inputValue.slice(cursorPosition)
|
const newCursorPosition = cursorPosition + text.length
|
||||||
setInputValue(newValue)
|
setCursorPosition(newCursorPosition)
|
||||||
const newCursorPosition = cursorPosition + text.length
|
setIntendedCursorPosition(newCursorPosition)
|
||||||
setCursorPosition(newCursorPosition)
|
return
|
||||||
setIntendedCursorPosition(newCursorPosition)
|
}
|
||||||
return
|
const acceptedTypes = ["png", "jpeg", "webp"]
|
||||||
}
|
const imageFiles = files.filter((file) => {
|
||||||
const acceptedTypes = ["png", "jpeg", "webp"]
|
const [type, subtype] = file.type.split("/")
|
||||||
const imageFiles = files.filter((file) => {
|
return type === "image" && acceptedTypes.includes(subtype)
|
||||||
const [type, subtype] = file.type.split("/")
|
})
|
||||||
return type === "image" && acceptedTypes.includes(subtype)
|
if (!shouldDisableImages && imageFiles.length > 0) {
|
||||||
})
|
const imagePromises = imageFiles.map((file) => {
|
||||||
if (!shouldDisableImages && imageFiles.length > 0) {
|
return new Promise<string | null>((resolve) => {
|
||||||
const imagePromises = imageFiles.map((file) => {
|
const reader = new FileReader()
|
||||||
return new Promise<string | null>((resolve) => {
|
reader.onloadend = () => {
|
||||||
const reader = new FileReader()
|
if (reader.error) {
|
||||||
reader.onloadend = () => {
|
console.error("Error reading file:", reader.error)
|
||||||
if (reader.error) {
|
resolve(null)
|
||||||
console.error("Error reading file:", reader.error)
|
} else {
|
||||||
resolve(null)
|
const result = reader.result
|
||||||
} else {
|
resolve(typeof result === "string" ? result : null)
|
||||||
const result = reader.result
|
|
||||||
console.log("File read successfully", result)
|
|
||||||
resolve(typeof result === "string" ? result : null)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
reader.readAsDataURL(file)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
const imageDataArray = await Promise.all(imagePromises)
|
|
||||||
const dataUrls = imageDataArray.filter((dataUrl): dataUrl is string => dataUrl !== null)
|
|
||||||
if (dataUrls.length > 0) {
|
|
||||||
setSelectedImages((prevImages) => [...prevImages, ...dataUrls].slice(0, MAX_IMAGES_PER_MESSAGE))
|
|
||||||
if (typeof vscode !== 'undefined') {
|
|
||||||
vscode.postMessage({
|
|
||||||
type: 'draggedImages',
|
|
||||||
dataUrls: dataUrls
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
} else {
|
reader.readAsDataURL(file)
|
||||||
console.warn("No valid images were processed")
|
})
|
||||||
|
})
|
||||||
|
const imageDataArray = await Promise.all(imagePromises)
|
||||||
|
const dataUrls = imageDataArray.filter((dataUrl): dataUrl is string => dataUrl !== null)
|
||||||
|
if (dataUrls.length > 0) {
|
||||||
|
setSelectedImages((prevImages) => [...prevImages, ...dataUrls].slice(0, MAX_IMAGES_PER_MESSAGE))
|
||||||
|
if (typeof vscode !== 'undefined') {
|
||||||
|
vscode.postMessage({
|
||||||
|
type: 'draggedImages',
|
||||||
|
dataUrls: dataUrls
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
console.warn("No valid images were processed")
|
||||||
}
|
}
|
||||||
}}
|
}
|
||||||
onDragOver={(e) => {
|
}}
|
||||||
e.preventDefault()
|
onDragOver={(e) => {
|
||||||
}}
|
e.preventDefault()
|
||||||
>
|
}}>
|
||||||
{showContextMenu && (
|
{showContextMenu && (
|
||||||
<div ref={contextMenuContainerRef}>
|
<div ref={contextMenuContainerRef}>
|
||||||
<ContextMenu
|
<ContextMenu
|
||||||
@@ -533,7 +598,7 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
|
|||||||
borderTop: 0,
|
borderTop: 0,
|
||||||
borderColor: "transparent",
|
borderColor: "transparent",
|
||||||
borderBottom: `${thumbnailsHeight + 6}px solid transparent`,
|
borderBottom: `${thumbnailsHeight + 6}px solid transparent`,
|
||||||
padding: "9px 49px 3px 9px",
|
padding: "9px 9px 25px 9px",
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<DynamicTextArea
|
<DynamicTextArea
|
||||||
@@ -588,11 +653,11 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
|
|||||||
borderTop: 0,
|
borderTop: 0,
|
||||||
borderBottom: `${thumbnailsHeight + 6}px solid transparent`,
|
borderBottom: `${thumbnailsHeight + 6}px solid transparent`,
|
||||||
borderColor: "transparent",
|
borderColor: "transparent",
|
||||||
|
padding: "9px 9px 25px 9px",
|
||||||
// borderRight: "54px solid transparent",
|
// borderRight: "54px solid transparent",
|
||||||
// borderLeft: "9px solid transparent", // NOTE: react-textarea-autosize doesn't calculate correct height when using borderLeft/borderRight so we need to use horizontal padding instead
|
// borderLeft: "9px solid transparent", // NOTE: react-textarea-autosize doesn't calculate correct height when using borderLeft/borderRight so we need to use horizontal padding instead
|
||||||
// Instead of using boxShadow, we use a div with a border to better replicate the behavior when the textarea is focused
|
// Instead of using boxShadow, we use a div with a border to better replicate the behavior when the textarea is focused
|
||||||
// boxShadow: "0px 0px 0px 1px var(--vscode-input-border)",
|
// boxShadow: "0px 0px 0px 1px var(--vscode-input-border)",
|
||||||
padding: "9px 49px 3px 9px",
|
|
||||||
cursor: textAreaDisabled ? "not-allowed" : undefined,
|
cursor: textAreaDisabled ? "not-allowed" : undefined,
|
||||||
flex: 1,
|
flex: 1,
|
||||||
zIndex: 1,
|
zIndex: 1,
|
||||||
@@ -609,45 +674,29 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
|
|||||||
paddingTop: 4,
|
paddingTop: 4,
|
||||||
bottom: 14,
|
bottom: 14,
|
||||||
left: 22,
|
left: 22,
|
||||||
right: 67, // (54 + 9) + 4 extra padding
|
right: 67,
|
||||||
zIndex: 2,
|
zIndex: 2,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<div
|
<div className="button-row" style={{ position: "absolute", right: 20, display: "flex", alignItems: "center", height: 31, bottom: 8, zIndex: 2, justifyContent: "flex-end" }}>
|
||||||
style={{
|
<span style={{ display: "flex", alignItems: "center", gap: 12 }}>
|
||||||
position: "absolute",
|
{apiConfiguration?.apiProvider === "openrouter" && (
|
||||||
right: 28,
|
<div style={{ display: "flex", alignItems: "center" }}>
|
||||||
display: "flex",
|
{isEnhancingPrompt && <span style={{ marginRight: 10, color: "var(--vscode-input-foreground)", opacity: 0.5 }}>Enhancing prompt...</span>}
|
||||||
alignItems: "flex-end",
|
<span
|
||||||
height: textAreaBaseHeight || 31,
|
role="button"
|
||||||
bottom: 18,
|
aria-label="enhance prompt"
|
||||||
zIndex: 2,
|
data-testid="enhance-prompt-button"
|
||||||
}}>
|
className={`input-icon-button ${textAreaDisabled ? "disabled" : ""} codicon codicon-sparkle`}
|
||||||
<div style={{ display: "flex", flexDirection: "row", alignItems: "center" }}>
|
onClick={() => !textAreaDisabled && handleEnhancePrompt()}
|
||||||
<div
|
style={{ fontSize: 16.5 }}
|
||||||
className={`input-icon-button ${
|
/>
|
||||||
shouldDisableImages ? "disabled" : ""
|
</div>
|
||||||
} codicon codicon-device-camera`}
|
)}
|
||||||
onClick={() => {
|
<span className={`input-icon-button ${shouldDisableImages ? "disabled" : ""} codicon codicon-device-camera`} onClick={() => !shouldDisableImages && onSelectImages()} style={{ fontSize: 16.5 }} />
|
||||||
if (!shouldDisableImages) {
|
<span className={`input-icon-button ${textAreaDisabled ? "disabled" : ""} codicon codicon-send`} onClick={() => !textAreaDisabled && onSend()} style={{ fontSize: 15 }} />
|
||||||
onSelectImages()
|
</span>
|
||||||
}
|
|
||||||
}}
|
|
||||||
style={{
|
|
||||||
marginRight: 5.5,
|
|
||||||
fontSize: 16.5,
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<div
|
|
||||||
className={`input-icon-button ${textAreaDisabled ? "disabled" : ""} codicon codicon-send`}
|
|
||||||
onClick={() => {
|
|
||||||
if (!textAreaDisabled) {
|
|
||||||
onSend()
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
style={{ fontSize: 15 }}></div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
|||||||
185
webview-ui/src/components/chat/__tests__/ChatTextArea.test.tsx
Normal file
185
webview-ui/src/components/chat/__tests__/ChatTextArea.test.tsx
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
/* eslint-disable import/first */
|
||||||
|
import React from 'react';
|
||||||
|
import { render, fireEvent, screen } from '@testing-library/react';
|
||||||
|
import '@testing-library/jest-dom';
|
||||||
|
import ChatTextArea from '../ChatTextArea';
|
||||||
|
import { useExtensionState } from '../../../context/ExtensionStateContext';
|
||||||
|
import { vscode } from '../../../utils/vscode';
|
||||||
|
|
||||||
|
// Mock modules
|
||||||
|
jest.mock('../../../utils/vscode', () => ({
|
||||||
|
vscode: {
|
||||||
|
postMessage: jest.fn()
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
jest.mock('../../../components/common/CodeBlock');
|
||||||
|
jest.mock('../../../components/common/MarkdownBlock');
|
||||||
|
|
||||||
|
// Get the mocked postMessage function
|
||||||
|
const mockPostMessage = vscode.postMessage as jest.Mock;
|
||||||
|
/* eslint-enable import/first */
|
||||||
|
|
||||||
|
// Mock ExtensionStateContext
|
||||||
|
jest.mock('../../../context/ExtensionStateContext');
|
||||||
|
|
||||||
|
describe('ChatTextArea', () => {
|
||||||
|
const defaultProps = {
|
||||||
|
inputValue: '',
|
||||||
|
setInputValue: jest.fn(),
|
||||||
|
onSend: jest.fn(),
|
||||||
|
textAreaDisabled: false,
|
||||||
|
onSelectImages: jest.fn(),
|
||||||
|
shouldDisableImages: false,
|
||||||
|
placeholderText: 'Type a message...',
|
||||||
|
selectedImages: [],
|
||||||
|
setSelectedImages: jest.fn(),
|
||||||
|
onHeightChange: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
// Default mock implementation for useExtensionState
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'anthropic',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('enhance prompt button', () => {
|
||||||
|
it('should show enhance prompt button only when apiProvider is openrouter', () => {
|
||||||
|
// Test with non-openrouter provider
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'anthropic',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const { rerender } = render(<ChatTextArea {...defaultProps} />);
|
||||||
|
expect(screen.queryByTestId('enhance-prompt-button')).not.toBeInTheDocument();
|
||||||
|
|
||||||
|
// Test with openrouter provider
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
rerender(<ChatTextArea {...defaultProps} />);
|
||||||
|
const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
|
||||||
|
expect(enhanceButton).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should be disabled when textAreaDisabled is true', () => {
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<ChatTextArea {...defaultProps} textAreaDisabled={true} />);
|
||||||
|
const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
|
||||||
|
expect(enhanceButton).toHaveClass('disabled');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('handleEnhancePrompt', () => {
|
||||||
|
it('should send message with correct configuration when clicked', () => {
|
||||||
|
const apiConfiguration = {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
};
|
||||||
|
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration,
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<ChatTextArea {...defaultProps} inputValue="Test prompt" />);
|
||||||
|
|
||||||
|
const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
|
||||||
|
fireEvent.click(enhanceButton);
|
||||||
|
|
||||||
|
expect(mockPostMessage).toHaveBeenCalledWith({
|
||||||
|
type: 'enhancePrompt',
|
||||||
|
text: 'Test prompt',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not send message when input is empty', () => {
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<ChatTextArea {...defaultProps} inputValue="" />);
|
||||||
|
|
||||||
|
const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
|
||||||
|
fireEvent.click(enhanceButton);
|
||||||
|
|
||||||
|
expect(mockPostMessage).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should show loading state while enhancing', () => {
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<ChatTextArea {...defaultProps} inputValue="Test prompt" />);
|
||||||
|
|
||||||
|
const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
|
||||||
|
fireEvent.click(enhanceButton);
|
||||||
|
|
||||||
|
expect(screen.getByText('Enhancing prompt...')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('effect dependencies', () => {
|
||||||
|
it('should update when apiConfiguration changes', () => {
|
||||||
|
const { rerender } = render(<ChatTextArea {...defaultProps} />);
|
||||||
|
|
||||||
|
// Update apiConfiguration
|
||||||
|
(useExtensionState as jest.Mock).mockReturnValue({
|
||||||
|
filePaths: [],
|
||||||
|
apiConfiguration: {
|
||||||
|
apiProvider: 'openrouter',
|
||||||
|
newSetting: 'test',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
rerender(<ChatTextArea {...defaultProps} />);
|
||||||
|
|
||||||
|
// Verify the enhance button appears after apiConfiguration changes
|
||||||
|
expect(screen.getByRole('button', { name: /enhance prompt/i })).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('enhanced prompt response', () => {
|
||||||
|
it('should update input value when receiving enhanced prompt', () => {
|
||||||
|
const setInputValue = jest.fn();
|
||||||
|
|
||||||
|
render(<ChatTextArea {...defaultProps} setInputValue={setInputValue} />);
|
||||||
|
|
||||||
|
// Simulate receiving enhanced prompt message
|
||||||
|
window.dispatchEvent(
|
||||||
|
new MessageEvent('message', {
|
||||||
|
data: {
|
||||||
|
type: 'enhancedPrompt',
|
||||||
|
text: 'Enhanced test prompt',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(setInputValue).toHaveBeenCalledWith('Enhanced test prompt');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
12
webview-ui/src/components/common/__mocks__/CodeBlock.tsx
Normal file
12
webview-ui/src/components/common/__mocks__/CodeBlock.tsx
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import * as React from 'react';
|
||||||
|
|
||||||
|
interface CodeBlockProps {
|
||||||
|
children?: React.ReactNode;
|
||||||
|
language?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CodeBlock: React.FC<CodeBlockProps> = () => (
|
||||||
|
<div data-testid="mock-code-block">Mocked Code Block</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
export default CodeBlock;
|
||||||
12
webview-ui/src/components/common/__mocks__/MarkdownBlock.tsx
Normal file
12
webview-ui/src/components/common/__mocks__/MarkdownBlock.tsx
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import * as React from 'react';
|
||||||
|
|
||||||
|
interface MarkdownBlockProps {
|
||||||
|
children?: React.ReactNode;
|
||||||
|
content?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const MarkdownBlock: React.FC<MarkdownBlockProps> = ({ content }) => (
|
||||||
|
<div data-testid="mock-markdown-block">{content}</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
export default MarkdownBlock;
|
||||||
Reference in New Issue
Block a user