diff --git a/src/api/index.ts b/src/api/index.ts index ec35c2a..3fee5c4 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -11,6 +11,10 @@ import { GeminiHandler } from "./providers/gemini" import { OpenAiNativeHandler } from "./providers/openai-native" import { ApiStream } from "./transform/stream" +export interface SingleCompletionHandler { + completePrompt(prompt: string): Promise +} + export interface ApiHandler { createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream getModel(): { id: string; info: ModelInfo } diff --git a/src/api/providers/__tests__/openrouter.test.ts b/src/api/providers/__tests__/openrouter.test.ts index d2df3ee..fb24516 100644 --- a/src/api/providers/__tests__/openrouter.test.ts +++ b/src/api/providers/__tests__/openrouter.test.ts @@ -51,6 +51,14 @@ describe('OpenRouterHandler', () => { }) }) + test('getModel returns default model info when options are not provided', () => { + const handler = new OpenRouterHandler({}) + const result = handler.getModel() + + expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta') + expect(result.info.supportsPromptCache).toBe(true) + }) + test('createMessage generates correct stream chunks', async () => { const handler = new OpenRouterHandler(mockOptions) const mockStream = { @@ -118,4 +126,158 @@ describe('OpenRouterHandler', () => { stream: true })) }) + + test('createMessage with middle-out transform enabled', async () => { + const handler = new OpenRouterHandler({ + ...mockOptions, + openRouterUseMiddleOutTransform: true + }) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + id: 'test-id', + choices: [{ + delta: { + content: 'test response' + } + }] + } + } + } + + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate } + } as any + ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) + + await handler.createMessage('test', []).next() + + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ + transforms: ['middle-out'] + })) + }) + + test('createMessage with Claude model adds cache control', async () => { + const handler = new OpenRouterHandler({ + ...mockOptions, + openRouterModelId: 'anthropic/claude-3.5-sonnet' + }) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + id: 'test-id', + choices: [{ + delta: { + content: 'test response' + } + }] + } + } + } + + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate } + } as any + ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) + + const messages: Anthropic.Messages.MessageParam[] = [ + { role: 'user', content: 'message 1' }, + { role: 'assistant', content: 'response 1' }, + { role: 'user', content: 'message 2' } + ] + + await handler.createMessage('test system', messages).next() + + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'system', + content: expect.arrayContaining([ + expect.objectContaining({ + cache_control: { type: 'ephemeral' } + }) + ]) + }) + ]) + })) + }) + + test('createMessage handles API errors', async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + error: { + message: 'API Error', + code: 500 + } + } + } + } + + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate } + } as any + + const generator = handler.createMessage('test', []) + await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error') + }) + + test('completePrompt returns correct response', async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockResponse = { + choices: [{ + message: { + content: 'test completion' + } + }] + } + + const mockCreate = jest.fn().mockResolvedValue(mockResponse) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate } + } as any + + const result = await handler.completePrompt('test prompt') + + expect(result).toBe('test completion') + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.openRouterModelId, + messages: [{ role: 'user', content: 'test prompt' }], + temperature: 0, + stream: false + }) + }) + + test('completePrompt handles API errors', async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockError = { + error: { + message: 'API Error', + code: 500 + } + } + + const mockCreate = jest.fn().mockResolvedValue(mockError) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate } + } as any + + await expect(handler.completePrompt('test prompt')) + .rejects.toThrow('OpenRouter API Error 500: API Error') + }) + + test('completePrompt handles unexpected errors', async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error')) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate } + } as any + + await expect(handler.completePrompt('test prompt')) + .rejects.toThrow('OpenRouter completion error: Unexpected error') + }) }) diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index c2c34d8..ccfe167 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -4,11 +4,11 @@ import OpenAI from "openai" import { ApiHandler } from "../" import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api" import { convertToOpenAiMessages } from "../transform/openai-format" -import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream" +import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream" import delay from "delay" // Add custom interface for OpenRouter params -interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming { +type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & { transforms?: string[]; } @@ -17,7 +17,12 @@ interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk { fullResponseText: string; } -export class OpenRouterHandler implements ApiHandler { +// Interface for providers that support single completions +export interface SingleCompletionHandler { + completePrompt(prompt: string): Promise +} + +export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler { private options: ApiHandlerOptions private client: OpenAI @@ -184,4 +189,28 @@ export class OpenRouterHandler implements ApiHandler { } return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo } } + + async completePrompt(prompt: string): Promise { + try { + const response = await this.client.chat.completions.create({ + model: this.getModel().id, + messages: [{ role: "user", content: prompt }], + temperature: 0, + stream: false + }) + + if ("error" in response) { + const error = response.error as { message?: string; code?: number } + throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`) + } + + const completion = response as OpenAI.Chat.ChatCompletion + return completion.choices[0]?.message?.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`OpenRouter completion error: ${error.message}`) + } + throw error + } + } } diff --git a/src/core/Cline.ts b/src/core/Cline.ts index 343393f..f8b76bc 100644 --- a/src/core/Cline.ts +++ b/src/core/Cline.ts @@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for" import * as path from "path" import { serializeError } from "serialize-error" import * as vscode from "vscode" -import { ApiHandler, buildApiHandler } from "../api" +import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api" import { ApiStream } from "../api/transform/stream" import { DiffViewProvider } from "../integrations/editor/DiffViewProvider" import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown" @@ -49,6 +49,7 @@ import { truncateHalfConversation } from "./sliding-window" import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider" import { detectCodeOmission } from "../integrations/editor/detect-omission" import { BrowserSession } from "../services/browser/BrowserSession" +import { OpenRouterHandler } from "../api/providers/openrouter" const cwd = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution @@ -126,6 +127,22 @@ export class Cline { } } + async enhancePrompt(promptText: string): Promise { + if (!promptText) { + throw new Error("No prompt text provided") + } + + const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no bullet points): ${promptText}` + + // Check if the API handler supports completePrompt + if (this.api instanceof OpenRouterHandler) { + return this.api.completePrompt(prompt) + } + + // Otherwise just return the prompt as is + return prompt; + } + // Storing task to disk for history private async ensureTaskDirectoryExists(): Promise { diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 029d77a..1702f1f 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -23,6 +23,7 @@ import { openMention } from "../mentions" import { getNonce } from "./getNonce" import { getUri } from "./getUri" import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound" +import { enhancePrompt } from "../../utils/enhance-prompt" /* https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts @@ -637,6 +638,26 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("writeDelayMs", message.value) await this.postStateToWebview() break + case "enhancePrompt": + if (message.text) { + try { + const { apiConfiguration } = await this.getState() + const enhanceConfig = { + ...apiConfiguration, + apiProvider: "openrouter" as const, + openRouterModelId: "gpt-4o", + } + const enhancedPrompt = await enhancePrompt(enhanceConfig, message.text) + await this.postMessageToWebview({ + type: "enhancedPrompt", + text: enhancedPrompt + }) + } catch (error) { + console.error("Error enhancing prompt:", error) + vscode.window.showErrorMessage("Failed to enhance prompt") + } + } + break } }, null, diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index f17724d..92357d4 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -19,6 +19,7 @@ export interface ExtensionMessage { | "openRouterModels" | "openAiModels" | "mcpServers" + | "enhancedPrompt" text?: string action?: | "chatButtonClicked" diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 113aae2..c0b7b1d 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -43,6 +43,9 @@ export interface WebviewMessage { | "fuzzyMatchThreshold" | "preferredLanguage" | "writeDelayMs" + | "enhancePrompt" + | "enhancedPrompt" + | "draggedImages" text?: string disabled?: boolean askResponse?: ClineAskResponse @@ -52,10 +55,10 @@ export interface WebviewMessage { value?: number commands?: string[] audioType?: AudioType - // For toggleToolAutoApprove serverName?: string toolName?: string alwaysAllow?: boolean + dataUrls?: string[] } export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse" diff --git a/src/utils/__tests__/enhance-prompt.test.ts b/src/utils/__tests__/enhance-prompt.test.ts new file mode 100644 index 0000000..ab8f253 --- /dev/null +++ b/src/utils/__tests__/enhance-prompt.test.ts @@ -0,0 +1,80 @@ +import { enhancePrompt } from '../enhance-prompt' +import { buildApiHandler } from '../../api' +import { ApiConfiguration } from '../../shared/api' +import { OpenRouterHandler } from '../../api/providers/openrouter' + +// Mock the buildApiHandler function +jest.mock('../../api', () => ({ + buildApiHandler: jest.fn() +})) + +describe('enhancePrompt', () => { + const mockApiConfig: ApiConfiguration = { + apiProvider: 'openrouter', + apiKey: 'test-key', + openRouterApiKey: 'test-key', + openRouterModelId: 'test-model' + } + + // Create a mock handler that looks like OpenRouterHandler + const mockHandler = { + completePrompt: jest.fn(), + createMessage: jest.fn(), + getModel: jest.fn() + } + + // Make instanceof check work + Object.setPrototypeOf(mockHandler, OpenRouterHandler.prototype) + + beforeEach(() => { + jest.clearAllMocks() + ;(buildApiHandler as jest.Mock).mockReturnValue(mockHandler) + }) + + it('should throw error for non-OpenRouter providers', async () => { + const nonOpenRouterConfig: ApiConfiguration = { + apiProvider: 'anthropic', + apiKey: 'test-key', + apiModelId: 'claude-3' + } + await expect(enhancePrompt(nonOpenRouterConfig, 'test')).rejects.toThrow('Prompt enhancement is only available with OpenRouter') + }) + + it('should enhance a valid prompt', async () => { + const inputPrompt = 'Write a function to sort an array' + const enhancedPrompt = 'Write a TypeScript function that implements an efficient sorting algorithm for a generic array, including error handling and type safety' + + mockHandler.completePrompt.mockResolvedValue(enhancedPrompt) + + const result = await enhancePrompt(mockApiConfig, inputPrompt) + + expect(result).toBe(enhancedPrompt) + expect(buildApiHandler).toHaveBeenCalledWith(mockApiConfig) + expect(mockHandler.completePrompt).toHaveBeenCalledWith( + expect.stringContaining(inputPrompt) + ) + }) + + it('should throw error when no prompt text is provided', async () => { + await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided') + expect(mockHandler.completePrompt).not.toHaveBeenCalled() + }) + + it('should pass through API errors', async () => { + const inputPrompt = 'Test prompt' + mockHandler.completePrompt.mockRejectedValue('API error') + + await expect(enhancePrompt(mockApiConfig, inputPrompt)).rejects.toBe('API error') + }) + + it('should pass the correct prompt format to the API', async () => { + const inputPrompt = 'Test prompt' + mockHandler.completePrompt.mockResolvedValue('Enhanced test prompt') + + await enhancePrompt(mockApiConfig, inputPrompt) + + expect(mockHandler.completePrompt).toHaveBeenCalledWith( + 'Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): Test prompt' + ) + }) +}) \ No newline at end of file diff --git a/src/utils/enhance-prompt.ts b/src/utils/enhance-prompt.ts new file mode 100644 index 0000000..3c68ced --- /dev/null +++ b/src/utils/enhance-prompt.ts @@ -0,0 +1,26 @@ +import { ApiConfiguration } from "../shared/api" +import { buildApiHandler } from "../api" +import { OpenRouterHandler, SingleCompletionHandler } from "../api/providers/openrouter" + +/** + * Enhances a prompt using the OpenRouter API without creating a full Cline instance or task history. + * This is a lightweight alternative that only uses the API's completion functionality. + */ +export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string): Promise { + if (!promptText) { + throw new Error("No prompt text provided") + } + if (apiConfiguration.apiProvider !== "openrouter") { + throw new Error("Prompt enhancement is only available with OpenRouter") + } + + const handler = buildApiHandler(apiConfiguration) + + // Type guard to check if handler is OpenRouterHandler + if (!(handler instanceof OpenRouterHandler)) { + throw new Error("Expected OpenRouter handler") + } + + const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): ${promptText}` + return handler.completePrompt(prompt) +} \ No newline at end of file diff --git a/webview-ui/src/components/chat/ChatTextArea.tsx b/webview-ui/src/components/chat/ChatTextArea.tsx index f48380b..39ec9f5 100644 --- a/webview-ui/src/components/chat/ChatTextArea.tsx +++ b/webview-ui/src/components/chat/ChatTextArea.tsx @@ -13,7 +13,7 @@ import { MAX_IMAGES_PER_MESSAGE } from "./ChatView" import ContextMenu from "./ContextMenu" import Thumbnails from "../common/Thumbnails" -declare const vscode: any; +import { vscode } from "../../utils/vscode" interface ChatTextAreaProps { inputValue: string @@ -44,8 +44,20 @@ const ChatTextArea = forwardRef( }, ref, ) => { - const { filePaths } = useExtensionState() + const { filePaths, apiConfiguration } = useExtensionState() const [isTextAreaFocused, setIsTextAreaFocused] = useState(false) + + // Handle enhanced prompt response + useEffect(() => { + const messageHandler = (event: MessageEvent) => { + const message = event.data + if (message.type === 'enhancedPrompt' && message.text) { + setInputValue(message.text) + } + } + window.addEventListener('message', messageHandler) + return () => window.removeEventListener('message', messageHandler) + }, [setInputValue]) const [thumbnailsHeight, setThumbnailsHeight] = useState(0) const [textAreaBaseHeight, setTextAreaBaseHeight] = useState(undefined) const [showContextMenu, setShowContextMenu] = useState(false) @@ -60,6 +72,63 @@ const ChatTextArea = forwardRef( const [intendedCursorPosition, setIntendedCursorPosition] = useState(null) const contextMenuContainerRef = useRef(null) + const [isEnhancingPrompt, setIsEnhancingPrompt] = useState(false) + + const handleEnhancePrompt = useCallback(() => { + if (!textAreaDisabled) { + const trimmedInput = inputValue.trim() + if (trimmedInput) { + setIsEnhancingPrompt(true) + const message = { + type: "enhancePrompt" as const, + text: trimmedInput, + } + vscode.postMessage(message) + } else { + const promptDescription = "The 'Enhance Prompt' button helps improve your prompt by providing additional context, clarification, or rephrasing. Try typing a prompt in here and clicking the button again to see how it works." + setInputValue(promptDescription) + } + } + }, [inputValue, textAreaDisabled, setInputValue]) + + useEffect(() => { + const messageHandler = (event: MessageEvent) => { + const message = event.data + if (message.type === 'enhancedPrompt') { + setInputValue(message.text) + setIsEnhancingPrompt(false) + } + } + window.addEventListener('message', messageHandler) + return () => window.removeEventListener('message', messageHandler) + }, [setInputValue]) + + // Handle enhanced prompt response + useEffect(() => { + const messageHandler = (event: MessageEvent) => { + const message = event.data + if (message.type === 'enhancedPrompt') { + setInputValue(message.text) + } + } + window.addEventListener('message', messageHandler) + return () => { + window.removeEventListener('message', messageHandler) + } + }, [setInputValue]) + + // Handle enhanced prompt response + useEffect(() => { + const messageHandler = (event: MessageEvent) => { + const message = event.data + if (message.type === 'enhancedPrompt' && message.text) { + setInputValue(message.text) + } + } + window.addEventListener('message', messageHandler) + return () => window.removeEventListener('message', messageHandler) + }, [setInputValue]) + const queryItems = useMemo(() => { return [ { type: ContextMenuOptionType.Problems, value: "problems" }, @@ -423,68 +492,64 @@ const ChatTextArea = forwardRef( ) return ( -
{ - console.log("onDrop called") - e.preventDefault() - const files = Array.from(e.dataTransfer.files) - const text = e.dataTransfer.getData("text") - if (text) { - const newValue = - inputValue.slice(0, cursorPosition) + text + inputValue.slice(cursorPosition) - setInputValue(newValue) - const newCursorPosition = cursorPosition + text.length - setCursorPosition(newCursorPosition) - setIntendedCursorPosition(newCursorPosition) - return - } - const acceptedTypes = ["png", "jpeg", "webp"] - const imageFiles = files.filter((file) => { - const [type, subtype] = file.type.split("/") - return type === "image" && acceptedTypes.includes(subtype) - }) - if (!shouldDisableImages && imageFiles.length > 0) { - const imagePromises = imageFiles.map((file) => { - return new Promise((resolve) => { - const reader = new FileReader() - reader.onloadend = () => { - if (reader.error) { - console.error("Error reading file:", reader.error) - resolve(null) - } else { - const result = reader.result - console.log("File read successfully", result) - resolve(typeof result === "string" ? result : null) - } +
{ + e.preventDefault() + const files = Array.from(e.dataTransfer.files) + const text = e.dataTransfer.getData("text") + if (text) { + const newValue = + inputValue.slice(0, cursorPosition) + text + inputValue.slice(cursorPosition) + setInputValue(newValue) + const newCursorPosition = cursorPosition + text.length + setCursorPosition(newCursorPosition) + setIntendedCursorPosition(newCursorPosition) + return + } + const acceptedTypes = ["png", "jpeg", "webp"] + const imageFiles = files.filter((file) => { + const [type, subtype] = file.type.split("/") + return type === "image" && acceptedTypes.includes(subtype) + }) + if (!shouldDisableImages && imageFiles.length > 0) { + const imagePromises = imageFiles.map((file) => { + return new Promise((resolve) => { + const reader = new FileReader() + reader.onloadend = () => { + if (reader.error) { + console.error("Error reading file:", reader.error) + resolve(null) + } else { + const result = reader.result + resolve(typeof result === "string" ? result : null) } - reader.readAsDataURL(file) - }) - }) - const imageDataArray = await Promise.all(imagePromises) - const dataUrls = imageDataArray.filter((dataUrl): dataUrl is string => dataUrl !== null) - if (dataUrls.length > 0) { - setSelectedImages((prevImages) => [...prevImages, ...dataUrls].slice(0, MAX_IMAGES_PER_MESSAGE)) - if (typeof vscode !== 'undefined') { - vscode.postMessage({ - type: 'draggedImages', - dataUrls: dataUrls - }) } - } else { - console.warn("No valid images were processed") + reader.readAsDataURL(file) + }) + }) + const imageDataArray = await Promise.all(imagePromises) + const dataUrls = imageDataArray.filter((dataUrl): dataUrl is string => dataUrl !== null) + if (dataUrls.length > 0) { + setSelectedImages((prevImages) => [...prevImages, ...dataUrls].slice(0, MAX_IMAGES_PER_MESSAGE)) + if (typeof vscode !== 'undefined') { + vscode.postMessage({ + type: 'draggedImages', + dataUrls: dataUrls + }) } + } else { + console.warn("No valid images were processed") } - }} - onDragOver={(e) => { - e.preventDefault() - }} - > + } + }} + onDragOver={(e) => { + e.preventDefault() + }}> {showContextMenu && (
( borderTop: 0, borderColor: "transparent", borderBottom: `${thumbnailsHeight + 6}px solid transparent`, - padding: "9px 49px 3px 9px", + padding: "9px 9px 25px 9px", }} /> ( borderTop: 0, borderBottom: `${thumbnailsHeight + 6}px solid transparent`, borderColor: "transparent", + padding: "9px 9px 25px 9px", // borderRight: "54px solid transparent", // borderLeft: "9px solid transparent", // NOTE: react-textarea-autosize doesn't calculate correct height when using borderLeft/borderRight so we need to use horizontal padding instead // Instead of using boxShadow, we use a div with a border to better replicate the behavior when the textarea is focused // boxShadow: "0px 0px 0px 1px var(--vscode-input-border)", - padding: "9px 49px 3px 9px", cursor: textAreaDisabled ? "not-allowed" : undefined, flex: 1, zIndex: 1, @@ -609,45 +674,29 @@ const ChatTextArea = forwardRef( paddingTop: 4, bottom: 14, left: 22, - right: 67, // (54 + 9) + 4 extra padding + right: 67, zIndex: 2, }} /> )} -
-
-
{ - if (!shouldDisableImages) { - onSelectImages() - } - }} - style={{ - marginRight: 5.5, - fontSize: 16.5, - }} - /> -
{ - if (!textAreaDisabled) { - onSend() - } - }} - style={{ fontSize: 15 }}>
-
+
+ + {apiConfiguration?.apiProvider === "openrouter" && ( +
+ {isEnhancingPrompt && Enhancing prompt...} + !textAreaDisabled && handleEnhancePrompt()} + style={{ fontSize: 16.5 }} + /> +
+ )} + !shouldDisableImages && onSelectImages()} style={{ fontSize: 16.5 }} /> + !textAreaDisabled && onSend()} style={{ fontSize: 15 }} /> +
) diff --git a/webview-ui/src/components/chat/__tests__/ChatTextArea.test.tsx b/webview-ui/src/components/chat/__tests__/ChatTextArea.test.tsx new file mode 100644 index 0000000..b6b2323 --- /dev/null +++ b/webview-ui/src/components/chat/__tests__/ChatTextArea.test.tsx @@ -0,0 +1,185 @@ +/* eslint-disable import/first */ +import React from 'react'; +import { render, fireEvent, screen } from '@testing-library/react'; +import '@testing-library/jest-dom'; +import ChatTextArea from '../ChatTextArea'; +import { useExtensionState } from '../../../context/ExtensionStateContext'; +import { vscode } from '../../../utils/vscode'; + +// Mock modules +jest.mock('../../../utils/vscode', () => ({ + vscode: { + postMessage: jest.fn() + } +})); +jest.mock('../../../components/common/CodeBlock'); +jest.mock('../../../components/common/MarkdownBlock'); + +// Get the mocked postMessage function +const mockPostMessage = vscode.postMessage as jest.Mock; +/* eslint-enable import/first */ + +// Mock ExtensionStateContext +jest.mock('../../../context/ExtensionStateContext'); + +describe('ChatTextArea', () => { + const defaultProps = { + inputValue: '', + setInputValue: jest.fn(), + onSend: jest.fn(), + textAreaDisabled: false, + onSelectImages: jest.fn(), + shouldDisableImages: false, + placeholderText: 'Type a message...', + selectedImages: [], + setSelectedImages: jest.fn(), + onHeightChange: jest.fn(), + }; + + beforeEach(() => { + jest.clearAllMocks(); + // Default mock implementation for useExtensionState + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'anthropic', + }, + }); + }); + + describe('enhance prompt button', () => { + it('should show enhance prompt button only when apiProvider is openrouter', () => { + // Test with non-openrouter provider + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'anthropic', + }, + }); + + const { rerender } = render(); + expect(screen.queryByTestId('enhance-prompt-button')).not.toBeInTheDocument(); + + // Test with openrouter provider + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'openrouter', + }, + }); + + rerender(); + const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i }); + expect(enhanceButton).toBeInTheDocument(); + }); + + it('should be disabled when textAreaDisabled is true', () => { + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'openrouter', + }, + }); + + render(); + const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i }); + expect(enhanceButton).toHaveClass('disabled'); + }); + }); + + describe('handleEnhancePrompt', () => { + it('should send message with correct configuration when clicked', () => { + const apiConfiguration = { + apiProvider: 'openrouter', + apiKey: 'test-key', + }; + + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration, + }); + + render(); + + const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i }); + fireEvent.click(enhanceButton); + + expect(mockPostMessage).toHaveBeenCalledWith({ + type: 'enhancePrompt', + text: 'Test prompt', + }); + }); + + it('should not send message when input is empty', () => { + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'openrouter', + }, + }); + + render(); + + const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i }); + fireEvent.click(enhanceButton); + + expect(mockPostMessage).not.toHaveBeenCalled(); + }); + + it('should show loading state while enhancing', () => { + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'openrouter', + }, + }); + + render(); + + const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i }); + fireEvent.click(enhanceButton); + + expect(screen.getByText('Enhancing prompt...')).toBeInTheDocument(); + }); + }); + + describe('effect dependencies', () => { + it('should update when apiConfiguration changes', () => { + const { rerender } = render(); + + // Update apiConfiguration + (useExtensionState as jest.Mock).mockReturnValue({ + filePaths: [], + apiConfiguration: { + apiProvider: 'openrouter', + newSetting: 'test', + }, + }); + + rerender(); + + // Verify the enhance button appears after apiConfiguration changes + expect(screen.getByRole('button', { name: /enhance prompt/i })).toBeInTheDocument(); + }); + }); + + describe('enhanced prompt response', () => { + it('should update input value when receiving enhanced prompt', () => { + const setInputValue = jest.fn(); + + render(); + + // Simulate receiving enhanced prompt message + window.dispatchEvent( + new MessageEvent('message', { + data: { + type: 'enhancedPrompt', + text: 'Enhanced test prompt', + }, + }) + ); + + expect(setInputValue).toHaveBeenCalledWith('Enhanced test prompt'); + }); + }); +}); \ No newline at end of file diff --git a/webview-ui/src/components/common/__mocks__/CodeBlock.tsx b/webview-ui/src/components/common/__mocks__/CodeBlock.tsx new file mode 100644 index 0000000..e261cd0 --- /dev/null +++ b/webview-ui/src/components/common/__mocks__/CodeBlock.tsx @@ -0,0 +1,12 @@ +import * as React from 'react'; + +interface CodeBlockProps { + children?: React.ReactNode; + language?: string; +} + +const CodeBlock: React.FC = () => ( +
Mocked Code Block
+); + +export default CodeBlock; \ No newline at end of file diff --git a/webview-ui/src/components/common/__mocks__/MarkdownBlock.tsx b/webview-ui/src/components/common/__mocks__/MarkdownBlock.tsx new file mode 100644 index 0000000..8aee781 --- /dev/null +++ b/webview-ui/src/components/common/__mocks__/MarkdownBlock.tsx @@ -0,0 +1,12 @@ +import * as React from 'react'; + +interface MarkdownBlockProps { + children?: React.ReactNode; + content?: string; +} + +const MarkdownBlock: React.FC = ({ content }) => ( +
{content}
+); + +export default MarkdownBlock; \ No newline at end of file