Enhance prompt button for OpenRouter

Matt Rubens
2024-12-23 23:24:49 -08:00
parent 1581ed135b
commit 111abdbb2c
13 changed files with 703 additions and 102 deletions

View File: src/api/index.ts

@@ -11,6 +11,10 @@ import { GeminiHandler } from "./providers/gemini"
import { OpenAiNativeHandler } from "./providers/openai-native"
import { ApiStream } from "./transform/stream"
export interface SingleCompletionHandler {
	completePrompt(prompt: string): Promise<string>
}

export interface ApiHandler {
	createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
	getModel(): { id: string; info: ModelInfo }
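
SingleCompletionHandler marks providers that can return a one-shot, non-streaming completion. A minimal sketch of how calling code might narrow an ApiHandler before using it (the supportsSingleCompletion helper is illustrative, not part of this commit):

	// Type guard: does this handler also implement SingleCompletionHandler?
	function supportsSingleCompletion(handler: ApiHandler): handler is ApiHandler & SingleCompletionHandler {
		return typeof (handler as Partial<SingleCompletionHandler>).completePrompt === "function"
	}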

View File: src/api/providers/__tests__/openrouter.test.ts

@@ -51,6 +51,14 @@ describe('OpenRouterHandler', () => {
		})
	})

	test('getModel returns default model info when options are not provided', () => {
		const handler = new OpenRouterHandler({})
		const result = handler.getModel()
		expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
		expect(result.info.supportsPromptCache).toBe(true)
	})

	test('createMessage generates correct stream chunks', async () => {
		const handler = new OpenRouterHandler(mockOptions)
		const mockStream = {
@@ -118,4 +126,158 @@ describe('OpenRouterHandler', () => {
			stream: true
		}))
	})
	test('createMessage with middle-out transform enabled', async () => {
		const handler = new OpenRouterHandler({
			...mockOptions,
			openRouterUseMiddleOutTransform: true
		})
		const mockStream = {
			async *[Symbol.asyncIterator]() {
				yield {
					id: 'test-id',
					choices: [{
						delta: {
							content: 'test response'
						}
					}]
				}
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockStream)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any
		;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })

		await handler.createMessage('test', []).next()

		expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
			transforms: ['middle-out']
		}))
	})

	test('createMessage with Claude model adds cache control', async () => {
		const handler = new OpenRouterHandler({
			...mockOptions,
			openRouterModelId: 'anthropic/claude-3.5-sonnet'
		})
		const mockStream = {
			async *[Symbol.asyncIterator]() {
				yield {
					id: 'test-id',
					choices: [{
						delta: {
							content: 'test response'
						}
					}]
				}
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockStream)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any
		;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })

		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: 'user', content: 'message 1' },
			{ role: 'assistant', content: 'response 1' },
			{ role: 'user', content: 'message 2' }
		]

		await handler.createMessage('test system', messages).next()

		expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
			messages: expect.arrayContaining([
				expect.objectContaining({
					role: 'system',
					content: expect.arrayContaining([
						expect.objectContaining({
							cache_control: { type: 'ephemeral' }
						})
					])
				})
			])
		}))
	})

	test('createMessage handles API errors', async () => {
		const handler = new OpenRouterHandler(mockOptions)
		const mockStream = {
			async *[Symbol.asyncIterator]() {
				yield {
					error: {
						message: 'API Error',
						code: 500
					}
				}
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockStream)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		const generator = handler.createMessage('test', [])
		await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
	})

	test('completePrompt returns correct response', async () => {
		const handler = new OpenRouterHandler(mockOptions)
		const mockResponse = {
			choices: [{
				message: {
					content: 'test completion'
				}
			}]
		}
		const mockCreate = jest.fn().mockResolvedValue(mockResponse)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		const result = await handler.completePrompt('test prompt')
		expect(result).toBe('test completion')
		expect(mockCreate).toHaveBeenCalledWith({
			model: mockOptions.openRouterModelId,
			messages: [{ role: 'user', content: 'test prompt' }],
			temperature: 0,
			stream: false
		})
	})

	test('completePrompt handles API errors', async () => {
		const handler = new OpenRouterHandler(mockOptions)
		const mockError = {
			error: {
				message: 'API Error',
				code: 500
			}
		}
		const mockCreate = jest.fn().mockResolvedValue(mockError)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		await expect(handler.completePrompt('test prompt'))
			.rejects.toThrow('OpenRouter API Error 500: API Error')
	})

	test('completePrompt handles unexpected errors', async () => {
		const handler = new OpenRouterHandler(mockOptions)
		const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		await expect(handler.completePrompt('test prompt'))
			.rejects.toThrow('OpenRouter completion error: Unexpected error')
	})
})
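
The tests above rely on module-level mocks declared earlier in the file, outside these hunks. A sketch of the assumed setup (the mockOptions values are illustrative):

	jest.mock('openai')
	jest.mock('axios')

	// Illustrative options; the real file may use different values
	const mockOptions = {
		openRouterApiKey: 'test-key',
		openRouterModelId: 'anthropic/claude-3.5-sonnet'
	}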

View File: src/api/providers/openrouter.ts

@@ -4,11 +4,11 @@ import OpenAI from "openai"
import { ApiHandler } from "../"
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
-import { ApiStreamChunk, ApiStreamUsageChunk } from "./transform/stream"
+import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "./transform/stream"
import delay from "delay"
// Add custom interface for OpenRouter params
-interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming {
+type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
	transforms?: string[];
}
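
The transforms field is an OpenRouter extension, not part of the OpenAI API, so the request type is widened with an intersection. A sketch of a request that uses it (the model and message are illustrative, and client is assumed to be an OpenAI instance configured with OpenRouter's base URL):

	const stream = await client.chat.completions.create({
		model: "anthropic/claude-3.5-sonnet",
		messages: [{ role: "user", content: "Hello" }],
		stream: true,
		transforms: ["middle-out"], // OpenRouter-specific: trims long context from the middle
	} as OpenRouterChatCompletionParams)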
@@ -17,7 +17,12 @@ interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
	fullResponseText: string;
}
-export class OpenRouterHandler implements ApiHandler {
+// Interface for providers that support single completions
+export interface SingleCompletionHandler {
+	completePrompt(prompt: string): Promise<string>
+}
+
+export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
	private options: ApiHandlerOptions
	private client: OpenAI
@@ -184,4 +189,28 @@ export class OpenRouterHandler implements ApiHandler {
		}
		return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
	}
	async completePrompt(prompt: string): Promise<string> {
		try {
			const response = await this.client.chat.completions.create({
				model: this.getModel().id,
				messages: [{ role: "user", content: prompt }],
				temperature: 0,
				stream: false
			})

			if ("error" in response) {
				const error = response.error as { message?: string; code?: number }
				throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
			}

			const completion = response as OpenAI.Chat.ChatCompletion
			return completion.choices[0]?.message?.content || ""
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(`OpenRouter completion error: ${error.message}`)
			}
			throw error
		}
	}
}
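
For context on the commit title: an enhance-prompt action can sit directly on top of completePrompt. A hypothetical caller (the function name and instruction text are illustrative, not taken from this diff):

	// Hypothetical usage sketch, not part of this commit
	async function enhancePrompt(handler: SingleCompletionHandler, userInput: string): Promise<string> {
		const instruction = "Generate an enhanced version of this prompt:"
		return handler.completePrompt(`${instruction}\n\n${userInput}`)
	}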