feat: add openai server api

This commit is contained in:
arvinxx
2023-05-25 01:29:35 +08:00
parent f2e07f3b0f
commit 59d381e77a
6 changed files with 273 additions and 0 deletions

View File

@@ -19,6 +19,7 @@
"dependencies": {
"@lobehub/ui": "^1",
"@vercel/analytics": "^1",
"langchain": "latest",
"next": "^13",
"react": "^18",
"react-dom": "^18"

View File

@@ -0,0 +1,90 @@
import { LangChainParams } from '@/types/langchain';
import { LLMChain } from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import {
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
} from 'langchain/prompts';
const isDev = process.env.NODE_ENV === 'development';
const OPENAI_PROXY_URL = process.env.OPENAI_PROXY_URL;
export function LangChainStream(payload: LangChainParams) {
const { prompts, vars, llm } = payload;
// 将 payload 中的消息转换为 ChatOpenAI 所需的 HumanChatMessage、SystemChatMessage 和 AIChatMessage 类型
const chatPrompt = ChatPromptTemplate.fromPromptMessages(
prompts.map((m) => {
switch (m.role) {
default:
case 'user':
return HumanMessagePromptTemplate.fromTemplate(m.content);
case 'system':
return SystemMessagePromptTemplate.fromTemplate(m.content);
case 'assistant':
return AIMessagePromptTemplate.fromTemplate(m.content);
}
}),
);
// 使用 TextEncoder 将字符串转换为字节数组,以便在 ReadableStream 中发送
const encoder = new TextEncoder();
// 初始化换行符计数器
return new ReadableStream({
async start(controller) {
let newlineCounter = 0;
const chat = new ChatOpenAI(
{
streaming: true,
...llm,
// 暂时设定不重试 ,后续看是否需要支持重试
maxRetries: 0,
callbacks: [
{
handleLLMNewToken(token) {
// 如果 message 是换行符,且 newlineCounter 小于 2那么跳过该换行符
if (newlineCounter < 2 && token === '\n') {
return;
}
// 将 message 编码为字节并添加到流中
const queue = encoder.encode(token);
controller.enqueue(queue);
newlineCounter++;
},
},
],
},
isDev && OPENAI_PROXY_URL ? { basePath: OPENAI_PROXY_URL } : undefined,
);
const chain = new LLMChain({
prompt: chatPrompt,
llm: chat,
verbose: true,
callbacks: [
{
handleChainError(err: Error): Promise<void> | void {
console.log(err.message);
},
},
],
});
try {
// 使用转换后的聊天消息作为输入开始聊天
await chain.call(vars);
// 完成后,关闭流
controller.close();
} catch (e) {
// 如果在执行过程中发生错误,向流发送错误
controller.error(e);
}
},
});
}

View File

@@ -0,0 +1,117 @@
import { ChatMessage } from '@lobehub/ui';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { AIChatMessage, HumanChatMessage, SystemChatMessage } from 'langchain/schema';
const isDev = process.env.NODE_ENV === 'development';
const OPENAI_PROXY_URL = process.env.OPENAI_PROXY_URL;
/**
* @title OpenAI Stream Payload
*/
export interface OpenAIStreamPayload {
/**
* @title 模型名称
*/
model: string;
/**
* @title 聊天信息列表
*/
messages: ChatMessage[];
/**
* @title 生成文本的随机度量,用于控制文本的创造性和多样性
* @default 0.5
*/
temperature: number;
/**
* @title 控制生成文本中最高概率的单个令牌
* @default 1
*/
top_p?: number;
/**
* @title 控制生成文本中的惩罚系数,用于减少重复性
* @default 0
*/
frequency_penalty?: number;
/**
* @title 控制生成文本中的惩罚系数,用于减少主题的变化
* @default 0
*/
presence_penalty?: number;
/**
* @title 生成文本的最大长度
*/
max_tokens?: number;
/**
* @title 是否开启流式请求
* @default true
*/
stream?: boolean;
/**
* @title 返回的文本数量
*/
n?: number;
}
export function OpenAIStream(payload: OpenAIStreamPayload) {
const { messages, ...params } = payload;
// 将 payload 中的消息转换为 ChatOpenAI 所需的 HumanChatMessage、SystemChatMessage 和 AIChatMessage 类型
const chatMessages = messages.map((m) => {
switch (m.role) {
default:
case 'user':
return new HumanChatMessage(m.content);
case 'system':
return new SystemChatMessage(m.content);
case 'assistant':
return new AIChatMessage(m.content);
}
});
// 使用 TextEncoder 将字符串转换为字节数组,以便在 ReadableStream 中发送
const encoder = new TextEncoder();
// 初始化换行符计数器
return new ReadableStream({
async start(controller) {
let newlineCounter = 0;
const chat = new ChatOpenAI(
{
streaming: true,
...params,
// 暂时设定不重试 ,后续看是否需要支持重试
maxRetries: 0,
callbacks: [
{
handleLLMNewToken(token) {
// 如果 message 是换行符,且 newlineCounter 小于 2那么跳过该换行符
if (newlineCounter < 2 && token === '\n') {
return;
}
// 将 message 编码为字节并添加到流中
const queue = encoder.encode(token);
controller.enqueue(queue);
newlineCounter++;
},
},
],
},
isDev && OPENAI_PROXY_URL ? { basePath: OPENAI_PROXY_URL } : undefined,
);
try {
// 使用转换后的聊天消息作为输入开始聊天
await chat.call(chatMessages);
// 完成后,关闭流
controller.close();
} catch (e) {
// 如果在执行过程中发生错误,向流发送错误
controller.error(e);
}
},
});
}

View File

@@ -0,0 +1,16 @@
import { LangChainParams } from '@/types/langchain';
import { LangChainStream } from './LangChainStream';
if (!process.env.OPENAI_API_KEY) {
throw new Error('Missing env var from OpenAI');
}
export const config = {
runtime: 'edge',
};
export default async function handler(request: Request) {
const payload = (await request.json()) as LangChainParams;
return new Response(LangChainStream(payload));
}

View File

@@ -0,0 +1,15 @@
import { OpenAIStream, OpenAIStreamPayload } from './OpenAIStream';
if (!process.env.OPENAI_API_KEY) {
throw new Error('Missing env var from OpenAI');
}
export const config = {
runtime: 'edge',
};
export default async function handler(request: Request) {
const payload = (await request.json()) as OpenAIStreamPayload;
return new Response(OpenAIStream(payload));
}

34
src/types/langchain.ts Normal file
View File

@@ -0,0 +1,34 @@
import { ChatMessage } from '@lobehub/ui';
export interface LangChainParams {
llm: {
model: string;
/**
* 生成文本的随机度量,用于控制文本的创造性和多样性
* @default 0.6
*/
temperature: number;
/**
* 控制生成文本中最高概率的单个令牌
*/
top_p?: number;
/**
* 控制生成文本中的惩罚系数,用于减少重复性
*/
frequency_penalty?: number;
/**
* 控制生成文本中的惩罚系数,用于减少主题的变化
*/
presence_penalty?: number;
/**
* 生成文本的最大长度
*/
max_tokens?: number;
};
/**
* 聊天信息列表
*/
prompts: ChatMessage[];
vars: Record<string, string>;
}