diff --git a/package.json b/package.json index 30de430f67..fff38149f4 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,7 @@ "dependencies": { "@lobehub/ui": "^1", "@vercel/analytics": "^1", + "langchain": "latest", "next": "^13", "react": "^18", "react-dom": "^18" diff --git a/src/pages/api/LangChainStream.ts b/src/pages/api/LangChainStream.ts new file mode 100644 index 0000000000..2853c477bb --- /dev/null +++ b/src/pages/api/LangChainStream.ts @@ -0,0 +1,90 @@ +import { LangChainParams } from '@/types/langchain'; +import { LLMChain } from 'langchain/chains'; +import { ChatOpenAI } from 'langchain/chat_models/openai'; +import { + AIMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +} from 'langchain/prompts'; + +const isDev = process.env.NODE_ENV === 'development'; +const OPENAI_PROXY_URL = process.env.OPENAI_PROXY_URL; + +export function LangChainStream(payload: LangChainParams) { + const { prompts, vars, llm } = payload; + + // 将 payload 中的消息转换为 ChatOpenAI 所需的 HumanChatMessage、SystemChatMessage 和 AIChatMessage 类型 + const chatPrompt = ChatPromptTemplate.fromPromptMessages( + prompts.map((m) => { + switch (m.role) { + default: + case 'user': + return HumanMessagePromptTemplate.fromTemplate(m.content); + case 'system': + return SystemMessagePromptTemplate.fromTemplate(m.content); + + case 'assistant': + return AIMessagePromptTemplate.fromTemplate(m.content); + } + }), + ); + + // 使用 TextEncoder 将字符串转换为字节数组,以便在 ReadableStream 中发送 + const encoder = new TextEncoder(); + + // 初始化换行符计数器 + + return new ReadableStream({ + async start(controller) { + let newlineCounter = 0; + + const chat = new ChatOpenAI( + { + streaming: true, + ...llm, + // 暂时设定不重试 ,后续看是否需要支持重试 + maxRetries: 0, + callbacks: [ + { + handleLLMNewToken(token) { + // 如果 message 是换行符,且 newlineCounter 小于 2,那么跳过该换行符 + if (newlineCounter < 2 && token === '\n') { + return; + } + + // 将 message 编码为字节并添加到流中 + const queue = encoder.encode(token); + controller.enqueue(queue); + newlineCounter++; + }, + }, + ], + }, + isDev && OPENAI_PROXY_URL ? { basePath: OPENAI_PROXY_URL } : undefined, + ); + + const chain = new LLMChain({ + prompt: chatPrompt, + llm: chat, + verbose: true, + callbacks: [ + { + handleChainError(err: Error): Promise | void { + console.log(err.message); + }, + }, + ], + }); + try { + // 使用转换后的聊天消息作为输入开始聊天 + await chain.call(vars); + // 完成后,关闭流 + controller.close(); + } catch (e) { + // 如果在执行过程中发生错误,向流发送错误 + controller.error(e); + } + }, + }); +} diff --git a/src/pages/api/OpenAIStream.ts b/src/pages/api/OpenAIStream.ts new file mode 100644 index 0000000000..3d92ecfd2a --- /dev/null +++ b/src/pages/api/OpenAIStream.ts @@ -0,0 +1,117 @@ +import { ChatMessage } from '@lobehub/ui'; +import { ChatOpenAI } from 'langchain/chat_models/openai'; +import { AIChatMessage, HumanChatMessage, SystemChatMessage } from 'langchain/schema'; + +const isDev = process.env.NODE_ENV === 'development'; +const OPENAI_PROXY_URL = process.env.OPENAI_PROXY_URL; + +/** + * @title OpenAI Stream Payload + */ +export interface OpenAIStreamPayload { + /** + * @title 模型名称 + */ + model: string; + /** + * @title 聊天信息列表 + */ + messages: ChatMessage[]; + /** + * @title 生成文本的随机度量,用于控制文本的创造性和多样性 + * @default 0.5 + */ + temperature: number; + /** + * @title 控制生成文本中最高概率的单个令牌 + * @default 1 + */ + top_p?: number; + /** + * @title 控制生成文本中的惩罚系数,用于减少重复性 + * @default 0 + */ + frequency_penalty?: number; + /** + * @title 控制生成文本中的惩罚系数,用于减少主题的变化 + * @default 0 + */ + presence_penalty?: number; + /** + * @title 生成文本的最大长度 + */ + max_tokens?: number; + /** + * @title 是否开启流式请求 + * @default true + */ + stream?: boolean; + /** + * @title 返回的文本数量 + */ + n?: number; +} + +export function OpenAIStream(payload: OpenAIStreamPayload) { + const { messages, ...params } = payload; + + // 将 payload 中的消息转换为 ChatOpenAI 所需的 HumanChatMessage、SystemChatMessage 和 AIChatMessage 类型 + const chatMessages = messages.map((m) => { + switch (m.role) { + default: + case 'user': + return new HumanChatMessage(m.content); + case 'system': + return new SystemChatMessage(m.content); + + case 'assistant': + return new AIChatMessage(m.content); + } + }); + + // 使用 TextEncoder 将字符串转换为字节数组,以便在 ReadableStream 中发送 + const encoder = new TextEncoder(); + + // 初始化换行符计数器 + + return new ReadableStream({ + async start(controller) { + let newlineCounter = 0; + + const chat = new ChatOpenAI( + { + streaming: true, + ...params, + // 暂时设定不重试 ,后续看是否需要支持重试 + maxRetries: 0, + callbacks: [ + { + handleLLMNewToken(token) { + // 如果 message 是换行符,且 newlineCounter 小于 2,那么跳过该换行符 + if (newlineCounter < 2 && token === '\n') { + return; + } + + // 将 message 编码为字节并添加到流中 + const queue = encoder.encode(token); + controller.enqueue(queue); + newlineCounter++; + }, + }, + ], + }, + isDev && OPENAI_PROXY_URL ? { basePath: OPENAI_PROXY_URL } : undefined, + ); + + try { + // 使用转换后的聊天消息作为输入开始聊天 + await chat.call(chatMessages); + // 完成后,关闭流 + controller.close(); + } catch (e) { + // 如果在执行过程中发生错误,向流发送错误 + controller.error(e); + } + }, + }); +} diff --git a/src/pages/api/chain.api.ts b/src/pages/api/chain.api.ts new file mode 100644 index 0000000000..3e17fe07dd --- /dev/null +++ b/src/pages/api/chain.api.ts @@ -0,0 +1,16 @@ +import { LangChainParams } from '@/types/langchain'; +import { LangChainStream } from './LangChainStream'; + +if (!process.env.OPENAI_API_KEY) { + throw new Error('Missing env var from OpenAI'); +} + +export const config = { + runtime: 'edge', +}; + +export default async function handler(request: Request) { + const payload = (await request.json()) as LangChainParams; + + return new Response(LangChainStream(payload)); +} diff --git a/src/pages/api/openai.api.ts b/src/pages/api/openai.api.ts new file mode 100644 index 0000000000..95c909d986 --- /dev/null +++ b/src/pages/api/openai.api.ts @@ -0,0 +1,15 @@ +import { OpenAIStream, OpenAIStreamPayload } from './OpenAIStream'; + +if (!process.env.OPENAI_API_KEY) { + throw new Error('Missing env var from OpenAI'); +} + +export const config = { + runtime: 'edge', +}; + +export default async function handler(request: Request) { + const payload = (await request.json()) as OpenAIStreamPayload; + + return new Response(OpenAIStream(payload)); +} diff --git a/src/types/langchain.ts b/src/types/langchain.ts new file mode 100644 index 0000000000..95aa5c5a50 --- /dev/null +++ b/src/types/langchain.ts @@ -0,0 +1,34 @@ +import { ChatMessage } from '@lobehub/ui'; + +export interface LangChainParams { + llm: { + model: string; + /** + * 生成文本的随机度量,用于控制文本的创造性和多样性 + * @default 0.6 + */ + temperature: number; + /** + * 控制生成文本中最高概率的单个令牌 + */ + top_p?: number; + /** + * 控制生成文本中的惩罚系数,用于减少重复性 + */ + frequency_penalty?: number; + /** + * 控制生成文本中的惩罚系数,用于减少主题的变化 + */ + presence_penalty?: number; + /** + * 生成文本的最大长度 + */ + max_tokens?: number; + }; + + /** + * 聊天信息列表 + */ + prompts: ChatMessage[]; + vars: Record; +}