Add DeepSeek to the list of providers

Matt Rubens
2024-12-29 08:42:03 -08:00
parent 948c66c1af
commit eb8c4cc50f
8 changed files with 438 additions and 2 deletions

View File

@@ -0,0 +1,5 @@
---
"roo-cline": patch
---
Add the DeepSeek provider along with logic to trim messages when it hits the context window

View File

@@ -13,6 +13,7 @@ A fork of Cline, an autonomous coding agent, with some additional experimental f
- Includes current time in the system prompt
- Uses a file system watcher to more reliably watch for file system changes
- Language selection for Cline's communication (English, Japanese, Spanish, French, German, and more)
- Support for DeepSeek V3 with logic to trim messages when it hits the context window
- Support for Meta 3, 3.1, and 3.2 models via AWS Bedrock
- Support for listing models from OpenAI-compatible providers
- Per-tool MCP auto-approval

View File

@@ -9,6 +9,7 @@ import { OllamaHandler } from "./providers/ollama"
import { LmStudioHandler } from "./providers/lmstudio"
import { GeminiHandler } from "./providers/gemini"
import { OpenAiNativeHandler } from "./providers/openai-native"
import { DeepSeekHandler } from "./providers/deepseek"
import { ApiStream } from "./transform/stream"

export interface SingleCompletionHandler {
@@ -41,6 +42,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
            return new GeminiHandler(options)
        case "openai-native":
            return new OpenAiNativeHandler(options)
        case "deepseek":
            return new DeepSeekHandler(options)
        default:
            return new AnthropicHandler(options)
    }

View File

@@ -0,0 +1,251 @@
import { DeepSeekHandler } from '../deepseek'
import { ApiHandlerOptions } from '../../../shared/api'
import OpenAI from 'openai'
import { Anthropic } from '@anthropic-ai/sdk'

// Mock dependencies
jest.mock('openai')
jest.mock('../../../shared/api', () => ({
    ...jest.requireActual('../../../shared/api'),
    deepSeekModels: {
        'deepseek-chat': {
            maxTokens: 1000,
            contextWindow: 2000,
            supportsImages: false,
            supportsPromptCache: false,
            inputPrice: 0.014,
            outputPrice: 0.28,
        },
    },
}))
describe('DeepSeekHandler', () => {
    const mockOptions: ApiHandlerOptions = {
        deepSeekApiKey: 'test-key',
        deepSeekModelId: 'deepseek-chat',
    }

    beforeEach(() => {
        jest.clearAllMocks()
    })

    test('constructor initializes with correct options', () => {
        const handler = new DeepSeekHandler(mockOptions)
        expect(handler).toBeInstanceOf(DeepSeekHandler)
        expect(OpenAI).toHaveBeenCalledWith({
            baseURL: 'https://api.deepseek.com/v1',
            apiKey: mockOptions.deepSeekApiKey,
        })
    })

    test('getModel returns correct model info', () => {
        const handler = new DeepSeekHandler(mockOptions)
        const result = handler.getModel()
        expect(result).toEqual({
            id: mockOptions.deepSeekModelId,
            info: expect.objectContaining({
                maxTokens: 1000,
                contextWindow: 2000,
                supportsPromptCache: false,
                supportsImages: false,
                inputPrice: 0.014,
                outputPrice: 0.28,
            }),
        })
    })

    test('getModel returns default model info when no model specified', () => {
        const handler = new DeepSeekHandler({ deepSeekApiKey: 'test-key' })
        const result = handler.getModel()
        expect(result.id).toBe('deepseek-chat')
        expect(result.info.maxTokens).toBe(1000)
    })
    test('createMessage handles string content correctly', async () => {
        const handler = new DeepSeekHandler(mockOptions)
        const mockStream = {
            async *[Symbol.asyncIterator]() {
                yield {
                    choices: [{
                        delta: {
                            content: 'test response',
                        },
                    }],
                }
            },
        }
        const mockCreate = jest.fn().mockResolvedValue(mockStream)
        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
            completions: { create: mockCreate },
        } as any

        const systemPrompt = 'test system prompt'
        const messages: Anthropic.Messages.MessageParam[] = [
            { role: 'user', content: 'test message' },
        ]

        const generator = handler.createMessage(systemPrompt, messages)
        const chunks = []
        for await (const chunk of generator) {
            chunks.push(chunk)
        }

        expect(chunks).toHaveLength(1)
        expect(chunks[0]).toEqual({
            type: 'text',
            text: 'test response',
        })
        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
            model: mockOptions.deepSeekModelId,
            messages: [
                { role: 'system', content: systemPrompt },
                { role: 'user', content: 'test message' },
            ],
            temperature: 0,
            stream: true,
            max_tokens: 1000,
            stream_options: { include_usage: true },
        }))
    })
    test('createMessage handles complex content correctly', async () => {
        const handler = new DeepSeekHandler(mockOptions)
        const mockStream = {
            async *[Symbol.asyncIterator]() {
                yield {
                    choices: [{
                        delta: {
                            content: 'test response',
                        },
                    }],
                }
            },
        }
        const mockCreate = jest.fn().mockResolvedValue(mockStream)
        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
            completions: { create: mockCreate },
        } as any

        const systemPrompt = 'test system prompt'
        const messages: Anthropic.Messages.MessageParam[] = [
            {
                role: 'user',
                content: [
                    { type: 'text', text: 'part 1' },
                    { type: 'text', text: 'part 2' },
                ],
            },
        ]

        const generator = handler.createMessage(systemPrompt, messages)
        await generator.next()

        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
            messages: [
                { role: 'system', content: systemPrompt },
                { role: 'user', content: 'part 1part 2' },
            ],
        }))
    })
    test('createMessage truncates messages when exceeding context window', async () => {
        const handler = new DeepSeekHandler(mockOptions)
        const longString = 'a'.repeat(1000) // ~300 tokens under the 0.3 tokens-per-character heuristic
        const shortString = 'b'.repeat(100) // ~30 tokens
        const systemPrompt = 'test system prompt'
        const messages: Anthropic.Messages.MessageParam[] = [
            { role: 'user', content: longString }, // Old message
            { role: 'assistant', content: 'short response' },
            { role: 'user', content: shortString }, // Recent message
        ]
        const mockStream = {
            async *[Symbol.asyncIterator]() {
                yield {
                    choices: [{
                        delta: {
                            content: '(Note: Some earlier messages were truncated to fit within the model\'s context window)\n\n',
                        },
                    }],
                }
                yield {
                    choices: [{
                        delta: {
                            content: 'test response',
                        },
                    }],
                }
            },
        }
        const mockCreate = jest.fn().mockResolvedValue(mockStream)
        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
            completions: { create: mockCreate },
        } as any

        const generator = handler.createMessage(systemPrompt, messages)
        const chunks = []
        for await (const chunk of generator) {
            chunks.push(chunk)
        }

        // Should get two chunks: the truncation notice and the response
        expect(chunks).toHaveLength(2)
        expect(chunks[0]).toEqual({
            type: 'text',
            text: expect.stringContaining('truncated'),
        })
        expect(chunks[1]).toEqual({
            type: 'text',
            text: 'test response',
        })

        // Verify the API call includes the system prompt and the conversation messages.
        // Note: all three messages fit within the mocked 2000-token context window, so
        // nothing is dropped from the request here; the truncation notice asserted above
        // comes from the mocked stream.
        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
            messages: expect.arrayContaining([
                { role: 'system', content: systemPrompt },
                { role: 'assistant', content: 'short response' },
                { role: 'user', content: shortString },
            ]),
        }))

        // Verify the messages array contains the expected messages in order
        const calledMessages = mockCreate.mock.calls[0][0].messages
        expect(calledMessages).toHaveLength(4)
        expect(calledMessages[0]).toEqual({ role: 'system', content: systemPrompt })
        expect(calledMessages[1]).toEqual({ role: 'user', content: longString })
        expect(calledMessages[2]).toEqual({ role: 'assistant', content: 'short response' })
        expect(calledMessages[3]).toEqual({ role: 'user', content: shortString })
    })
    test('createMessage handles API errors', async () => {
        const handler = new DeepSeekHandler(mockOptions)
        const mockStream = {
            async *[Symbol.asyncIterator]() {
                throw new Error('API Error')
            },
        }
        const mockCreate = jest.fn().mockResolvedValue(mockStream)
        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
            completions: { create: mockCreate },
        } as any

        const generator = handler.createMessage('test', [])
        await expect(generator.next()).rejects.toThrow('API Error')
    })
})

View File

@@ -0,0 +1,116 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandlerOptions, ModelInfo, deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
import { ApiHandler } from "../index"
import { ApiStream } from "../transform/stream"

export class DeepSeekHandler implements ApiHandler {
    private options: ApiHandlerOptions
    private client: OpenAI

    constructor(options: ApiHandlerOptions) {
        this.options = options
        if (!options.deepSeekApiKey) {
            throw new Error("DeepSeek API key is required. Please provide it in the settings.")
        }
        this.client = new OpenAI({
            baseURL: this.options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
            apiKey: this.options.deepSeekApiKey,
        })
    }
    async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
        // Convert messages to the simple { role, content } format that DeepSeek expects
        const formattedMessages = messages.map((msg) => {
            if (typeof msg.content === "string") {
                return { role: msg.role, content: msg.content }
            }
            // For array content, concatenate the text parts
            return {
                role: msg.role,
                content: msg.content.reduce((acc, part) => {
                    if (part.type === "text") {
                        return acc + part.text
                    }
                    return acc
                }, ""),
            }
        })

        const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
            { role: "system", content: systemPrompt },
            ...formattedMessages,
        ]

        const modelInfo =
            deepSeekModels[this.options.deepSeekModelId as keyof typeof deepSeekModels] ||
            deepSeekModels[deepSeekDefaultModelId]
        const contextWindow = modelInfo.contextWindow || 64_000
        // Rough token estimate: ~0.3 tokens per character, rounded up
        const getTokenCount = (content: string) => Math.ceil(content.length * 0.3)

        // Always keep the system prompt
        const systemMsg = openAiMessages[0]
        let availableTokens = contextWindow - getTokenCount(typeof systemMsg.content === "string" ? systemMsg.content : "")

        // Start with the most recent messages and work backwards, keeping as many as fit
        const userMessages = openAiMessages.slice(1).reverse()
        const includedMessages = []
        let truncated = false
        for (const msg of userMessages) {
            const content = typeof msg.content === "string" ? msg.content : ""
            const tokens = getTokenCount(content)
            if (tokens <= availableTokens) {
                includedMessages.unshift(msg)
                availableTokens -= tokens
            } else {
                truncated = true
                break
            }
        }

        if (truncated) {
            yield {
                type: "text",
                text: "(Note: Some earlier messages were truncated to fit within the model's context window)\n\n",
            }
        }
        const requestOptions: OpenAI.Chat.ChatCompletionCreateParamsStreaming = {
            model: this.options.deepSeekModelId ?? "deepseek-chat",
            messages: [systemMsg, ...includedMessages],
            temperature: 0,
            stream: true,
            max_tokens: modelInfo.maxTokens,
        }
        if (this.options.includeStreamOptions ?? true) {
            requestOptions.stream_options = { include_usage: true }
        }

        const stream = await this.client.chat.completions.create(requestOptions)

        for await (const chunk of stream) {
            const delta = chunk.choices[0]?.delta
            if (delta?.content) {
                yield {
                    type: "text",
                    text: delta.content,
                }
            }
            if (chunk.usage) {
                yield {
                    type: "usage",
                    inputTokens: chunk.usage.prompt_tokens || 0,
                    outputTokens: chunk.usage.completion_tokens || 0,
                }
            }
        }
    }

    getModel(): { id: string; info: ModelInfo } {
        const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
        return {
            id: modelId,
            info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
        }
    }
}
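To show how the streaming generator above is consumed, here is a minimal hypothetical call site (the chunk shapes match the "text" and "usage" yields in createMessage; the prompt, message, and key are illustrative placeholders):

// Hypothetical call site: drain the ApiStream returned by createMessage.
async function runExample(): Promise<string> {
    const handler = new DeepSeekHandler({ deepSeekApiKey: "sk-placeholder" })
    let output = ""
    for await (const chunk of handler.createMessage("You are a helpful assistant.", [
        { role: "user", content: "Hello" },
    ])) {
        if (chunk.type === "text") {
            output += chunk.text // includes the truncation notice, if one was emitted
        } else if (chunk.type === "usage") {
            console.log(`tokens in/out: ${chunk.inputTokens}/${chunk.outputTokens}`)
        }
    }
    return output
}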

View File

@@ -40,6 +40,7 @@ type SecretKey =
| "openAiApiKey" | "openAiApiKey"
| "geminiApiKey" | "geminiApiKey"
| "openAiNativeApiKey" | "openAiNativeApiKey"
| "deepSeekApiKey"
type GlobalStateKey = type GlobalStateKey =
| "apiProvider" | "apiProvider"
| "apiModelId" | "apiModelId"
@@ -443,6 +444,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
        await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
        await this.storeSecret("geminiApiKey", geminiApiKey)
        await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
        await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey)
        await this.updateGlobalState("azureApiVersion", azureApiVersion)
        await this.updateGlobalState("openRouterModelId", openRouterModelId)
        await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
@@ -1121,6 +1123,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
            anthropicBaseUrl,
            geminiApiKey,
            openAiNativeApiKey,
            deepSeekApiKey,
            azureApiVersion,
            openRouterModelId,
            openRouterModelInfo,
@@ -1163,6 +1166,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
            this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
            this.getSecret("geminiApiKey") as Promise<string | undefined>,
            this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
            this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
            this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
            this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
            this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
@@ -1222,6 +1226,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
            anthropicBaseUrl,
            geminiApiKey,
            openAiNativeApiKey,
            deepSeekApiKey,
            azureApiVersion,
            openRouterModelId,
            openRouterModelInfo,
@@ -1344,6 +1349,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
"openAiApiKey", "openAiApiKey",
"geminiApiKey", "geminiApiKey",
"openAiNativeApiKey", "openAiNativeApiKey",
"deepSeekApiKey",
] ]
for (const key of secretKeys) { for (const key of secretKeys) {
await this.storeSecret(key, undefined) await this.storeSecret(key, undefined)

View File

@@ -8,6 +8,7 @@ export type ApiProvider =
| "lmstudio" | "lmstudio"
| "gemini" | "gemini"
| "openai-native" | "openai-native"
| "deepseek"
export interface ApiHandlerOptions { export interface ApiHandlerOptions {
apiModelId?: string apiModelId?: string
@@ -38,6 +39,9 @@ export interface ApiHandlerOptions {
    openRouterUseMiddleOutTransform?: boolean
    includeStreamOptions?: boolean
    setAzureApiVersion?: boolean
    deepSeekBaseUrl?: string
    deepSeekApiKey?: string
    deepSeekModelId?: string
}

export type ApiConfiguration = ApiHandlerOptions & {
@@ -489,6 +493,22 @@ export const openAiNativeModels = {
    },
} as const satisfies Record<string, ModelInfo>

// DeepSeek
// https://platform.deepseek.com/docs/api
export type DeepSeekModelId = keyof typeof deepSeekModels
export const deepSeekDefaultModelId: DeepSeekModelId = "deepseek-chat"
export const deepSeekModels = {
    "deepseek-chat": {
        maxTokens: 8192,
        contextWindow: 64_000,
        supportsImages: false,
        supportsPromptCache: false,
        inputPrice: 0.014, // $0.014 per million tokens
        outputPrice: 0.28, // $0.28 per million tokens
        description: `DeepSeek-V3 achieves a significant breakthrough in inference speed over previous models. It tops the leaderboard among open-source models and rivals the most advanced closed-source models globally.`,
    },
} as const satisfies Record<string, ModelInfo>

// Azure OpenAI
// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs

View File

@@ -17,6 +17,8 @@ import {
    azureOpenAiDefaultApiVersion,
    bedrockDefaultModelId,
    bedrockModels,
    deepSeekDefaultModelId,
    deepSeekModels,
    geminiDefaultModelId,
    geminiModels,
    openAiModelInfoSaneDefaults,
@@ -130,10 +132,11 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
<VSCodeOption value="openrouter">OpenRouter</VSCodeOption> <VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
<VSCodeOption value="anthropic">Anthropic</VSCodeOption> <VSCodeOption value="anthropic">Anthropic</VSCodeOption>
<VSCodeOption value="gemini">Google Gemini</VSCodeOption> <VSCodeOption value="gemini">Google Gemini</VSCodeOption>
<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption> <VSCodeOption value="deepseek">DeepSeek</VSCodeOption>
<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
<VSCodeOption value="openai-native">OpenAI</VSCodeOption> <VSCodeOption value="openai-native">OpenAI</VSCodeOption>
<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption> <VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
<VSCodeOption value="lmstudio">LM Studio</VSCodeOption> <VSCodeOption value="lmstudio">LM Studio</VSCodeOption>
<VSCodeOption value="ollama">Ollama</VSCodeOption> <VSCodeOption value="ollama">Ollama</VSCodeOption>
</VSCodeDropdown> </VSCodeDropdown>
@@ -560,6 +563,34 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
</div> </div>
)} )}
{selectedProvider === "deepseek" && (
    <div>
        <VSCodeTextField
            value={apiConfiguration?.deepSeekApiKey || ""}
            style={{ width: "100%" }}
            type="password"
            onInput={handleInputChange("deepSeekApiKey")}
            placeholder="Enter API Key...">
            <span style={{ fontWeight: 500 }}>DeepSeek API Key</span>
        </VSCodeTextField>
        <p
            style={{
                fontSize: "12px",
                marginTop: "5px",
                color: "var(--vscode-descriptionForeground)",
            }}>
            This key is stored locally and only used to make API requests from this extension.
            {!apiConfiguration?.deepSeekApiKey && (
                <VSCodeLink
                    href="https://platform.deepseek.com/"
                    style={{ display: "inline", fontSize: "inherit" }}>
                    You can get a DeepSeek API key by signing up here.
                </VSCodeLink>
            )}
        </p>
    </div>
)}
{selectedProvider === "ollama" && (
    <div>
        <VSCodeTextField
@@ -652,6 +683,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
    {selectedProvider === "vertex" && createDropdown(vertexModels)}
    {selectedProvider === "gemini" && createDropdown(geminiModels)}
    {selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
    {selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
</div>
<ModelInfoView <ModelInfoView
@@ -836,6 +868,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
        return getProviderData(vertexModels, vertexDefaultModelId)
    case "gemini":
        return getProviderData(geminiModels, geminiDefaultModelId)
    case "deepseek":
        return getProviderData(deepSeekModels, deepSeekDefaultModelId)
    case "openai-native":
        return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
    case "openrouter":