mirror of
https://github.com/pacnpal/Roo-Code.git
synced 2025-12-20 04:11:10 -05:00
Add non-streaming completePrompt to all providers
This commit is contained in:
@@ -1,230 +1,209 @@
|
||||
import { OpenAiNativeHandler } from "../openai-native"
|
||||
import OpenAI from "openai"
|
||||
import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api"
|
||||
import { Anthropic } from "@anthropic-ai/sdk"
|
||||
import { OpenAiNativeHandler } from '../openai-native';
|
||||
import { ApiHandlerOptions } from '../../../shared/api';
|
||||
import OpenAI from 'openai';
|
||||
import { Anthropic } from '@anthropic-ai/sdk';
|
||||
|
||||
// Mock OpenAI
|
||||
jest.mock("openai")
|
||||
|
||||
describe("OpenAiNativeHandler", () => {
|
||||
let handler: OpenAiNativeHandler
|
||||
let mockOptions: ApiHandlerOptions
|
||||
let mockOpenAIClient: jest.Mocked<OpenAI>
|
||||
let mockCreate: jest.Mock
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Setup mock options
|
||||
mockOptions = {
|
||||
openAiNativeApiKey: "test-api-key",
|
||||
apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts
|
||||
}
|
||||
|
||||
// Setup mock create function
|
||||
mockCreate = jest.fn()
|
||||
|
||||
// Setup mock OpenAI client
|
||||
mockOpenAIClient = {
|
||||
// Mock OpenAI client
|
||||
const mockCreate = jest.fn();
|
||||
jest.mock('openai', () => {
|
||||
return {
|
||||
__esModule: true,
|
||||
default: jest.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate,
|
||||
},
|
||||
},
|
||||
} as unknown as jest.Mocked<OpenAI>
|
||||
create: mockCreate.mockImplementation(async (options) => {
|
||||
if (!options.stream) {
|
||||
return {
|
||||
id: 'test-completion',
|
||||
choices: [{
|
||||
message: { role: 'assistant', content: 'Test response' },
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: 'Test response' },
|
||||
index: 0
|
||||
}],
|
||||
usage: null
|
||||
};
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
// Mock OpenAI constructor
|
||||
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIClient)
|
||||
describe('OpenAiNativeHandler', () => {
|
||||
let handler: OpenAiNativeHandler;
|
||||
let mockOptions: ApiHandlerOptions;
|
||||
|
||||
// Create handler instance
|
||||
handler = new OpenAiNativeHandler(mockOptions)
|
||||
})
|
||||
beforeEach(() => {
|
||||
mockOptions = {
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
};
|
||||
handler = new OpenAiNativeHandler(mockOptions);
|
||||
mockCreate.mockClear();
|
||||
});
|
||||
|
||||
describe("constructor", () => {
|
||||
it("should initialize with provided options", () => {
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: mockOptions.openAiNativeApiKey,
|
||||
})
|
||||
})
|
||||
})
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided options', () => {
|
||||
expect(handler).toBeInstanceOf(OpenAiNativeHandler);
|
||||
expect(handler.getModel().id).toBe(mockOptions.apiModelId);
|
||||
});
|
||||
|
||||
describe("getModel", () => {
|
||||
it("should return specified model when valid", () => {
|
||||
const result = handler.getModel()
|
||||
expect(result.id).toBe("gpt-4o") // Use the correct model ID
|
||||
})
|
||||
it('should initialize with empty API key', () => {
|
||||
const handlerWithoutKey = new OpenAiNativeHandler({
|
||||
apiModelId: 'gpt-4o',
|
||||
openAiNativeApiKey: ''
|
||||
});
|
||||
expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
|
||||
});
|
||||
});
|
||||
|
||||
it("should return default model when model ID is invalid", () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "invalid-model" as any,
|
||||
})
|
||||
const result = handler.getModel()
|
||||
expect(result.id).toBe(openAiNativeDefaultModelId)
|
||||
})
|
||||
|
||||
it("should return default model when model ID is not provided", () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: undefined,
|
||||
})
|
||||
const result = handler.getModel()
|
||||
expect(result.id).toBe(openAiNativeDefaultModelId)
|
||||
})
|
||||
})
|
||||
|
||||
describe("createMessage", () => {
|
||||
const systemPrompt = "You are a helpful assistant"
|
||||
describe('createMessage', () => {
|
||||
const systemPrompt = 'You are a helpful assistant.';
|
||||
const messages: Anthropic.Messages.MessageParam[] = [
|
||||
{ role: "user", content: "Hello" },
|
||||
]
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello!'
|
||||
}
|
||||
];
|
||||
|
||||
describe("o1 models", () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "o1-preview",
|
||||
})
|
||||
})
|
||||
it('should handle streaming responses', async () => {
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
it("should handle non-streaming response for o1 models", async () => {
|
||||
const mockResponse = {
|
||||
choices: [{ message: { content: "Hello there!" } }],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
},
|
||||
}
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const textChunks = chunks.filter(chunk => chunk.type === 'text');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].text).toBe('Test response');
|
||||
});
|
||||
|
||||
mockCreate.mockResolvedValueOnce(mockResponse)
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
const stream = handler.createMessage(systemPrompt, messages);
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello there!" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "o1-preview",
|
||||
messages: [
|
||||
{ role: "user", content: systemPrompt },
|
||||
{ role: "user", content: "Hello" },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle missing content in response", async () => {
|
||||
const mockResponse = {
|
||||
choices: [{ message: { content: null } }],
|
||||
usage: null,
|
||||
}
|
||||
|
||||
mockCreate.mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "" },
|
||||
{ type: "usage", inputTokens: 0, outputTokens: 0 },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe("streaming models", () => {
|
||||
beforeEach(() => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
...mockOptions,
|
||||
apiModelId: "gpt-4o",
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle streaming response", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: null },
|
||||
{ choices: [{ delta: { content: " there" } }], usage: null },
|
||||
{ choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
})()
|
||||
)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "text", text: " there" },
|
||||
{ type: "text", text: "!" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: "gpt-4o",
|
||||
temperature: 0,
|
||||
messages: [
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: "Hello" },
|
||||
],
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle empty delta content", async () => {
|
||||
const mockStream = [
|
||||
{ choices: [{ delta: {} }], usage: null },
|
||||
{ choices: [{ delta: { content: null } }], usage: null },
|
||||
{ choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
|
||||
]
|
||||
|
||||
mockCreate.mockResolvedValueOnce(
|
||||
(async function* () {
|
||||
for (const chunk of mockStream) {
|
||||
yield chunk
|
||||
}
|
||||
})()
|
||||
)
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
const results = []
|
||||
for await (const result of generator) {
|
||||
results.push(result)
|
||||
}
|
||||
|
||||
expect(results).toEqual([
|
||||
{ type: "text", text: "Hello" },
|
||||
{ type: "usage", inputTokens: 10, outputTokens: 5 },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
it("should handle API errors", async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error("API Error"))
|
||||
|
||||
const generator = handler.createMessage(systemPrompt, messages)
|
||||
await expect(async () => {
|
||||
for await (const _ of generator) {
|
||||
// consume generator
|
||||
for await (const chunk of stream) {
|
||||
// Should not reach here
|
||||
}
|
||||
}).rejects.toThrow("API Error")
|
||||
})
|
||||
})
|
||||
})
|
||||
}).rejects.toThrow('API Error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('completePrompt', () => {
|
||||
it('should complete prompt successfully with gpt-4o model', async () => {
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
temperature: 0
|
||||
});
|
||||
});
|
||||
|
||||
it('should complete prompt successfully with o1 model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should complete prompt successfully with o1-preview model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-preview',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-preview',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should complete prompt successfully with o1-mini model', async () => {
|
||||
handler = new OpenAiNativeHandler({
|
||||
apiModelId: 'o1-mini',
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('Test response');
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
model: 'o1-mini',
|
||||
messages: [{ role: 'user', content: 'Test prompt' }]
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle API errors', async () => {
|
||||
mockCreate.mockRejectedValueOnce(new Error('API Error'));
|
||||
await expect(handler.completePrompt('Test prompt'))
|
||||
.rejects.toThrow('OpenAI Native completion error: API Error');
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{ message: { content: '' } }]
|
||||
});
|
||||
const result = await handler.completePrompt('Test prompt');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return model info', () => {
|
||||
const modelInfo = handler.getModel();
|
||||
expect(modelInfo.id).toBe(mockOptions.apiModelId);
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
expect(modelInfo.info.maxTokens).toBe(4096);
|
||||
expect(modelInfo.info.contextWindow).toBe(128_000);
|
||||
});
|
||||
|
||||
it('should handle undefined model ID', () => {
|
||||
const handlerWithoutModel = new OpenAiNativeHandler({
|
||||
openAiNativeApiKey: 'test-api-key'
|
||||
});
|
||||
const modelInfo = handlerWithoutModel.getModel();
|
||||
expect(modelInfo.id).toBe('gpt-4o'); // Default model
|
||||
expect(modelInfo.info).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user