Mirror of https://github.com/pacnpal/Roo-Code.git (synced 2025-12-20 04:11:10 -05:00)
Continuing work on support for OpenRouter compression (#43)
src/api/providers/__tests__/openrouter.test.ts (new file, 121 lines, normal file)
@@ -0,0 +1,121 @@
import { OpenRouterHandler } from '../openrouter'
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
import OpenAI from 'openai'
import axios from 'axios'
import { Anthropic } from '@anthropic-ai/sdk'

// Mock dependencies
jest.mock('openai')
jest.mock('axios')
jest.mock('delay', () => jest.fn(() => Promise.resolve()))

describe('OpenRouterHandler', () => {
	const mockOptions: ApiHandlerOptions = {
		openRouterApiKey: 'test-key',
		openRouterModelId: 'test-model',
		openRouterModelInfo: {
			name: 'Test Model',
			description: 'Test Description',
			maxTokens: 1000,
			contextWindow: 2000,
			supportsPromptCache: true,
			inputPrice: 0.01,
			outputPrice: 0.02
		} as ModelInfo
	}

	beforeEach(() => {
		jest.clearAllMocks()
	})

	test('constructor initializes with correct options', () => {
		const handler = new OpenRouterHandler(mockOptions)
		expect(handler).toBeInstanceOf(OpenRouterHandler)
		expect(OpenAI).toHaveBeenCalledWith({
			baseURL: 'https://openrouter.ai/api/v1',
			apiKey: mockOptions.openRouterApiKey,
			defaultHeaders: {
				'HTTP-Referer': 'https://cline.bot',
				'X-Title': 'Cline',
			},
		})
	})

	test('getModel returns correct model info when options are provided', () => {
		const handler = new OpenRouterHandler(mockOptions)
		const result = handler.getModel()

		expect(result).toEqual({
			id: mockOptions.openRouterModelId,
			info: mockOptions.openRouterModelInfo
		})
	})

	test('createMessage generates correct stream chunks', async () => {
		const handler = new OpenRouterHandler(mockOptions)
		const mockStream = {
			async *[Symbol.asyncIterator]() {
				yield {
					id: 'test-id',
					choices: [{
						delta: {
							content: 'test response'
						}
					}]
				}
			}
		}

		// Mock OpenAI chat.completions.create
		const mockCreate = jest.fn().mockResolvedValue(mockStream)
		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
			completions: { create: mockCreate }
		} as any

		// Mock axios.get for generation details
		;(axios.get as jest.Mock).mockResolvedValue({
			data: {
				data: {
					native_tokens_prompt: 10,
					native_tokens_completion: 20,
					total_cost: 0.001
				}
			}
		})

		const systemPrompt = 'test system prompt'
		const messages: Anthropic.Messages.MessageParam[] = [{ role: 'user' as const, content: 'test message' }]

		const generator = handler.createMessage(systemPrompt, messages)
		const chunks = []

		for await (const chunk of generator) {
			chunks.push(chunk)
		}

		// Verify stream chunks
		expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
		expect(chunks[0]).toEqual({
			type: 'text',
			text: 'test response'
		})
		expect(chunks[1]).toEqual({
			type: 'usage',
			inputTokens: 10,
			outputTokens: 20,
			totalCost: 0.001,
			fullResponseText: 'test response'
		})

		// Verify OpenAI client was called with correct parameters
		expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
			model: mockOptions.openRouterModelId,
			temperature: 0,
			messages: expect.arrayContaining([
				{ role: 'system', content: systemPrompt },
				{ role: 'user', content: 'test message' }
			]),
			stream: true
		}))
	})
})
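The suite above covers the constructor, getModel, and the default createMessage path; the new openRouterUseMiddleOutTransform flag (wired up in openrouter.ts below) is not yet exercised. A possible follow-up test, sketched with the same mocking pattern (hypothetical, not part of this commit):

test('createMessage passes middle-out transform when enabled', async () => {
	const handler = new OpenRouterHandler({
		...mockOptions,
		openRouterUseMiddleOutTransform: true,
	})
	const mockStream = {
		async *[Symbol.asyncIterator]() {
			yield { id: 'test-id', choices: [{ delta: { content: 'ok' } }] }
		}
	}
	const mockCreate = jest.fn().mockResolvedValue(mockStream)
	;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
		completions: { create: mockCreate }
	} as any
	;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })

	for await (const _chunk of handler.createMessage('sys', [{ role: 'user', content: 'hi' }])) {
		// drain the generator so the request is actually issued
	}

	expect(mockCreate).toHaveBeenCalledWith(
		expect.objectContaining({ transforms: ['middle-out'] })
	)
})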
src/api/providers/openrouter.ts
@@ -4,9 +4,19 @@ import OpenAI from "openai"
 import { ApiHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { ApiStream } from "../transform/stream"
+import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
 import delay from "delay"
 
+// Add custom interface for OpenRouter params
+interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming {
+	transforms?: string[];
+}
+
+// Add custom interface for OpenRouter usage chunk
+interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
+	fullResponseText: string;
+}
+
 export class OpenRouterHandler implements ApiHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
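The OpenAI SDK's ChatCompletionCreateParamsStreaming type has no transforms field, since that is an OpenRouter-specific extension, so the subtype above is what lets the request built later in this file type-check. A minimal illustration (the model id is a placeholder):

const params: OpenRouterChatCompletionParams = {
	model: "test-model", // placeholder id
	messages: [{ role: "user", content: "hello" }],
	stream: true,
	transforms: ["middle-out"], // OpenRouter-only field, accepted thanks to the subtype
}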
@@ -23,7 +33,7 @@ export class OpenRouterHandler implements ApiHandler {
 		})
 	}
 
-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): AsyncGenerator<ApiStreamChunk> {
 		// Convert Anthropic messages to OpenAI format
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
@@ -95,17 +105,21 @@ export class OpenRouterHandler implements ApiHandler {
 				maxTokens = 8_192
 				break
 		}
+		// https://openrouter.ai/docs/transforms
+		let fullResponseText = "";
 		const stream = await this.client.chat.completions.create({
 			model: this.getModel().id,
 			max_tokens: maxTokens,
 			temperature: 0,
 			messages: openAiMessages,
 			stream: true,
-		})
+			// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
+			...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] })
+		} as OpenRouterChatCompletionParams);
 
 		let genId: string | undefined
 
-		for await (const chunk of stream) {
+		for await (const chunk of stream as unknown as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
 			// openrouter returns an error object instead of the openai sdk throwing an error
 			if ("error" in chunk) {
 				const error = chunk.error as { message?: string; code?: number }
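The conditional spread above relies on the fact that spreading false (or undefined) into an object literal contributes no properties, so transforms only appears in the request when the flag is set. A standalone sketch of the idiom:

const on = { stream: true, ...(true && { transforms: ["middle-out"] }) }
const off = { stream: true, ...(false && { transforms: ["middle-out"] }) }
console.log(on)  // { stream: true, transforms: ['middle-out'] }
console.log(off) // { stream: true }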
@@ -119,10 +133,11 @@ export class OpenRouterHandler implements ApiHandler {
 
 			const delta = chunk.choices[0]?.delta
 			if (delta?.content) {
+				fullResponseText += delta.content;
 				yield {
 					type: "text",
 					text: delta.content,
-				}
+				} as ApiStreamChunk;
 			}
 			// if (chunk.usage) {
 			// 	yield {
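For reference, the chunk shapes implied by the casts above and by the test expectations. This is an assumed sketch, not the actual source (the real definitions live in ../transform/stream); fullResponseText is the field this commit adds on top of the usage variant:

// Assumed shapes, inferred from the tests in this commit.
type ApiStreamTextChunk = { type: "text"; text: string }
type ApiStreamUsageChunk = {
	type: "usage"
	inputTokens: number
	outputTokens: number
	totalCost: number
}
type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk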
@@ -153,13 +168,14 @@ export class OpenRouterHandler implements ApiHandler {
 					inputTokens: generation?.native_tokens_prompt || 0,
 					outputTokens: generation?.native_tokens_completion || 0,
 					totalCost: generation?.total_cost || 0,
-				}
+					fullResponseText
+				} as OpenRouterApiStreamUsageChunk;
 			} catch (error) {
 				// ignore if fails
 				console.error("Error fetching OpenRouter generation details:", error)
 			}
 		}
 
 	}
 	getModel(): { id: string; info: ModelInfo } {
 		const modelId = this.options.openRouterModelId
 		const modelInfo = this.options.openRouterModelInfo
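Taken together, a consumer would use the new option and the enriched usage chunk roughly like this. A sketch only: it assumes the options bag accepts a partial object, and omits the model fields so the imported OpenRouter defaults apply.

async function main() {
	const handler = new OpenRouterHandler({
		openRouterApiKey: "test-key", // placeholder
		openRouterUseMiddleOutTransform: true, // opt in to middle-out compression
	})

	for await (const chunk of handler.createMessage("You are concise.", [{ role: "user", content: "hello" }])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			// the usage chunk now also carries fullResponseText
			console.log(`\ntokens: ${chunk.inputTokens} in / ${chunk.outputTokens} out, cost ${chunk.totalCost}`)
		}
	}
}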