This commit is contained in:
Matt Rubens
2024-12-29 12:18:23 -08:00
parent 25987dd40b
commit 6290f90fa5
4 changed files with 36 additions and 95 deletions

View File

@@ -137,7 +137,13 @@ describe('DeepSeekHandler', () => {
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
messages: [ messages: [
{ role: 'system', content: systemPrompt }, { role: 'system', content: systemPrompt },
{ role: 'user', content: 'part 1part 2' } {
role: 'user',
content: [
{ type: 'text', text: 'part 1' },
{ type: 'text', text: 'part 2' }
]
}
] ]
})) }))
}) })

View File

@@ -1,96 +1,26 @@
import { Anthropic } from "@anthropic-ai/sdk" import { OpenAiHandler } from "./openai"
import OpenAI from "openai" import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
import { ApiHandlerOptions, ModelInfo, deepSeekModels, deepSeekDefaultModelId } from "../../shared/api" import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
import { ApiHandler } from "../index"
import { ApiStream } from "../transform/stream"
export class DeepSeekHandler implements ApiHandler {
private options: ApiHandlerOptions
private client: OpenAI
export class DeepSeekHandler extends OpenAiHandler {
constructor(options: ApiHandlerOptions) { constructor(options: ApiHandlerOptions) {
this.options = options
if (!options.deepSeekApiKey) { if (!options.deepSeekApiKey) {
throw new Error("DeepSeek API key is required. Please provide it in the settings.") throw new Error("DeepSeek API key is required. Please provide it in the settings.")
} }
this.client = new OpenAI({ super({
baseURL: this.options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1", ...options,
apiKey: this.options.deepSeekApiKey, openAiApiKey: options.deepSeekApiKey,
openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
includeMaxTokens: true
}) })
} }
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { override getModel(): { id: string; info: ModelInfo } {
const modelInfo = deepSeekModels[this.options.deepSeekModelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
// Format all messages
const messagesToInclude: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: 'system' as const, content: systemPrompt }
]
// Add the rest of the messages
for (const msg of messages) {
let messageContent = ""
if (typeof msg.content === "string") {
messageContent = msg.content
} else if (Array.isArray(msg.content)) {
messageContent = msg.content.reduce((acc, part) => {
if (part.type === "text") {
return acc + part.text
}
return acc
}, "")
}
messagesToInclude.push({
role: msg.role === 'user' ? 'user' as const : 'assistant' as const,
content: messageContent
})
}
const requestOptions: OpenAI.Chat.ChatCompletionCreateParamsStreaming = {
model: this.options.deepSeekModelId ?? "deepseek-chat",
messages: messagesToInclude,
temperature: 0,
stream: true,
max_tokens: modelInfo.maxTokens,
}
if (this.options.includeStreamOptions ?? true) {
requestOptions.stream_options = { include_usage: true }
}
let totalInputTokens = 0;
let totalOutputTokens = 0;
try {
const stream = await this.client.chat.completions.create(requestOptions)
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
if (delta?.content) {
yield {
type: "text",
text: delta.content,
}
}
if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
}
}
} catch (error) {
console.error("DeepSeek API Error:", error)
throw error
}
}
getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
return { return {
id: modelId, id: modelId,
info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId], info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
} }
} }
} }

View File

@@ -11,7 +11,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream" import { ApiStream } from "../transform/stream"
export class OpenAiHandler implements ApiHandler { export class OpenAiHandler implements ApiHandler {
private options: ApiHandlerOptions protected options: ApiHandlerOptions
private client: OpenAI private client: OpenAI
constructor(options: ApiHandlerOptions) { constructor(options: ApiHandlerOptions) {
@@ -38,12 +38,16 @@ export class OpenAiHandler implements ApiHandler {
{ role: "system", content: systemPrompt }, { role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages), ...convertToOpenAiMessages(messages),
] ]
const modelInfo = this.getModel().info
const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
model: this.options.openAiModelId ?? "", model: this.options.openAiModelId ?? "",
messages: openAiMessages, messages: openAiMessages,
temperature: 0, temperature: 0,
stream: true, stream: true,
} }
if (this.options.includeMaxTokens) {
requestOptions.max_tokens = modelInfo.maxTokens
}
if (this.options.includeStreamOptions ?? true) { if (this.options.includeStreamOptions ?? true) {
requestOptions.stream_options = { include_usage: true } requestOptions.stream_options = { include_usage: true }

View File

@@ -42,6 +42,7 @@ export interface ApiHandlerOptions {
deepSeekBaseUrl?: string deepSeekBaseUrl?: string
deepSeekApiKey?: string deepSeekApiKey?: string
deepSeekModelId?: string deepSeekModelId?: string
includeMaxTokens?: boolean
} }
export type ApiConfiguration = ApiHandlerOptions & { export type ApiConfiguration = ApiHandlerOptions & {