feat: 优化插件模式下的用户体验 (#13)

This commit is contained in:
Arvin Xu
2023-07-23 22:26:52 +08:00
committed by GitHub
parent f3a9902829
commit 4596f12b33
17 changed files with 323 additions and 126 deletions

View File

@@ -13,7 +13,7 @@ const withPWA = nextPWA({
const nextConfig = {
reactStrictMode: true,
pageExtensions: ['page.tsx', 'api.ts'],
transpilePackages: ['@lobehub/ui', 'antd-style'],
transpilePackages: ['@lobehub/ui'],
webpack(config) {
config.experiments = {
@@ -31,8 +31,8 @@ const nextConfig = {
destination: `${API_END_PORT_URL}/api/openai`,
},
{
source: '/api/chain-dev',
destination: `${API_END_PORT_URL}/api/chain`,
source: '/api/plugins-dev',
destination: `${API_END_PORT_URL}/api/plugins`,
},
];
},

2
src/const/message.ts Normal file
View File

@@ -0,0 +1,2 @@
// Any message whose content starts with this prefix is treated as a
// function-call message (the serialized `{"function_call": ...}` JSON
// emitted by the model instead of plain text).
export const FUNCTION_MESSAGE_FLAG = '{"function';

View File

@@ -38,6 +38,7 @@ export default {
'plugin-realtimeWeather': '实时天气预报',
'plugin-searchEngine': '搜索引擎',
'pluginList': '插件列表',
'pluginLoading': '插件运行中...',
'profile': '助手身份',
'reset': '重置',
'searchAgentPlaceholder': '搜索助手和对话...',

View File

@@ -1,74 +1,15 @@
import { OpenAIStream, StreamingTextResponse } from 'ai';
import { Configuration, OpenAIApi } from 'openai-edge';
import { ChatCompletionFunctions, ChatCompletionRequestMessage } from 'openai-edge/types/api';
import { StreamingTextResponse } from 'ai';
import { OpenAIStreamPayload } from '@/types/openai';
import pluginList from '../../plugins';
import { createChatCompletion } from './openai';
export const runtime = 'edge';
const isDev = process.env.NODE_ENV === 'development';
const OPENAI_PROXY_URL = process.env.OPENAI_PROXY_URL;
// Create an OpenAI API client (that's edge friendly!)
const config = new Configuration({
apiKey: process.env.OPENAI_API_KEY,
});
const openai = new OpenAIApi(config, isDev && OPENAI_PROXY_URL ? OPENAI_PROXY_URL : undefined);
export default async function handler(req: Request) {
const {
messages,
plugins: enabledPlugins,
...params
} = (await req.json()) as OpenAIStreamPayload;
const payload = (await req.json()) as OpenAIStreamPayload;
// ============ 1. 前置处理 functions ============ //
const filterFunctions: ChatCompletionFunctions[] = pluginList
.filter((p) => {
// 如果不存在 enabledPlugins,那么全部不启用
if (!enabledPlugins) return false;
// 如果存在 enabledPlugins,那么只启用 enabledPlugins 中的插件
return enabledPlugins.includes(p.name);
})
.map((f) => f.schema);
const functions = filterFunctions.length === 0 ? undefined : filterFunctions;
// ============ 2. 前置处理 messages ============ //
const formatMessages = messages.map((m) => ({ content: m.content, role: m.role }));
const response = await openai.createChatCompletion({
functions,
messages: formatMessages,
stream: true,
...params,
});
const stream = OpenAIStream(response, {
experimental_onFunctionCall: async ({ name, arguments: args }, createFunctionCallMessages) => {
console.log(`执行 functionCall [${name}]`, 'args:', args);
const func = pluginList.find((f) => f.name === name);
if (func) {
const result = await func.runner(args as any);
const newMessages = createFunctionCallMessages(result) as ChatCompletionRequestMessage[];
return openai.createChatCompletion({
functions,
messages: [...formatMessages, ...newMessages],
stream: true,
...params,
});
}
},
});
const stream = await createChatCompletion(payload);
return new StreamingTextResponse(stream);
}

50
src/pages/api/openai.ts Normal file
View File

@@ -0,0 +1,50 @@
import { OpenAIStream, OpenAIStreamCallbacks } from 'ai';
import { Configuration, OpenAIApi } from 'openai-edge';
import { ChatCompletionFunctions } from 'openai-edge/types/api';
import { OpenAIStreamPayload } from '@/types/openai';
import pluginList from '../../plugins';
// True only under `next dev`; used to decide whether to route OpenAI
// traffic through a local proxy.
const isDev = process.env.NODE_ENV === 'development';
const OPENAI_PROXY_URL = process.env.OPENAI_PROXY_URL;

// Create an OpenAI API client (that's edge friendly!)
const config = new Configuration({
  apiKey: process.env.OPENAI_API_KEY,
});

// Shared client instance. The proxy base path is applied only in
// development AND only when OPENAI_PROXY_URL is set.
export const openai = new OpenAIApi(
  config,
  isDev && OPENAI_PROXY_URL ? OPENAI_PROXY_URL : undefined,
);
/**
 * Build and send an OpenAI chat-completion request, returning a text stream.
 *
 * @param payload - chat payload; `plugins` selects which plugin schemas are
 *   exposed to the model as callable functions, every other field is
 *   forwarded to the OpenAI API as-is.
 * @param callbacks - optional factory that derives stream callbacks from the
 *   final request parameters (e.g. to chain a follow-up completion when the
 *   model emits a function call).
 */
export const createChatCompletion = async (
  payload: OpenAIStreamPayload,
  callbacks?: (payload: OpenAIStreamPayload) => OpenAIStreamCallbacks,
) => {
  const { messages, plugins: enabledPlugins, ...params } = payload;

  // ============ 1. Pre-process functions ============ //
  // No `enabledPlugins` list means no plugin is exposed; otherwise only the
  // listed plugins contribute their function schemas.
  const enabledSchemas: ChatCompletionFunctions[] = [];
  for (const plugin of pluginList) {
    if (enabledPlugins?.includes(plugin.name)) enabledSchemas.push(plugin.schema);
  }
  const functions = enabledSchemas.length > 0 ? enabledSchemas : undefined;

  // ============ 2. Pre-process messages ============ //
  // Keep only the fields the OpenAI API understands.
  const formatMessages = messages.map(({ content, role }) => ({ content, role }));

  const requestParams = { functions, messages: formatMessages, stream: true, ...params };

  const response = await openai.createChatCompletion(requestParams);
  return OpenAIStream(response, callbacks?.(requestParams));
};

View File

@@ -0,0 +1,34 @@
import { StreamingTextResponse } from 'ai';
import { ChatCompletionRequestMessage } from 'openai-edge';
import { OpenAIStreamPayload } from '@/types/openai';
import pluginList from '../../plugins';
import { createChatCompletion, openai } from './openai';
// Run on the Edge runtime so the streamed response works.
export const runtime = 'edge';

/**
 * Plugin-enabled chat endpoint: streams the completion and, whenever the
 * model emits a function call, executes the matching plugin and feeds its
 * result back into a follow-up completion.
 */
export default async function handler(req: Request) {
  const payload = (await req.json()) as OpenAIStreamPayload;

  // Note: the callback receives the fully-resolved request parameters, not
  // the raw incoming payload, so the follow-up call reuses the same config.
  const stream = await createChatCompletion(payload, (requestParams) => ({
    experimental_onFunctionCall: async ({ name, arguments: args }, createFunctionCallMessages) => {
      console.log(`执行 functionCall [${name}]`, 'args:', args);

      const plugin = pluginList.find((p) => p.name === name);
      if (!plugin) return;

      const result = await plugin.runner(args as any);
      const functionMessages = createFunctionCallMessages(result) as ChatCompletionRequestMessage[];

      // Ask the model again with the plugin output appended to the history.
      return openai.createChatCompletion({
        ...requestParams,
        messages: [...requestParams.messages, ...functionMessages],
      });
    },
  }));

  return new StreamingTextResponse(stream);
}

View File

@@ -1,10 +1,12 @@
import { ChatList } from '@lobehub/ui';
import { ChatList, ChatMessage } from '@lobehub/ui';
import isEqual from 'fast-deep-equal';
import { memo } from 'react';
import { ReactNode, memo } from 'react';
import { shallow } from 'zustand/shallow';
import { chatSelectors, useSessionStore } from '@/store/session';
import { isFunctionMessage } from '@/utils/message';
import FunctionMessage from './FunctionMessage';
import MessageExtra from './MessageExtra';
const List = () => {
@@ -14,6 +16,13 @@ const List = () => {
shallow,
);
const renderMessage = (content: ReactNode, message: ChatMessage) => {
if (message.role === 'function')
return isFunctionMessage(message.content) ? <FunctionMessage /> : content;
return content;
};
return (
<ChatList
data={data}
@@ -33,6 +42,7 @@ const List = () => {
onMessageChange={(id, content) => {
dispatchMessage({ id, key: 'content', type: 'updateMessage', value: content });
}}
renderMessage={renderMessage}
renderMessageExtra={MessageExtra}
style={{ marginTop: 24 }}
/>

View File

@@ -0,0 +1,31 @@
import { LoadingOutlined } from '@ant-design/icons';
import { createStyles } from 'antd-style';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
const useStyles = createStyles(({ css, token }) => ({
  container: css`
    padding: 4px 8px;
    color: ${token.colorTextSecondary};
    background: ${token.colorFillTertiary};
    border: 1px solid ${token.colorBorder};
    border-radius: 6px;
  `,
}));

/**
 * Inline placeholder shown while a plugin (function call) is running:
 * a spinner followed by the localized "plugin loading" label.
 */
const FunctionMessage = memo(() => {
  const { styles } = useStyles();
  const { t } = useTranslation();

  const spinner = (
    <div>
      <LoadingOutlined />
    </div>
  );

  return (
    <Flexbox className={styles.container} gap={8} horizontal>
      {spinner}
      {t('pluginLoading')}
    </Flexbox>
  );
});

export default FunctionMessage;

View File

@@ -1,9 +1,13 @@
import { Avatar, Icon, Tooltip } from '@lobehub/ui';
import { Tag } from 'antd';
import { createStyles } from 'antd-style';
import { LucideToyBrick } from 'lucide-react';
import { ReactNode } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import { shallow } from 'zustand/shallow';
import pluginList from '@/plugins';
import { agentSelectors, useSessionStore } from '@/store/session';
import { ChatMessage } from '@/types/chatMessage';
@@ -11,27 +15,70 @@ const useStyles = createStyles(({ css }) => ({
container: css`
margin-top: 8px;
`,
plugin: css`
display: flex;
gap: 4px;
align-items: center;
width: fit-content;
`,
}));
const MessageExtra = ({ role, extra }: ChatMessage): ReactNode => {
const MessageExtra = ({ role, extra, function_call }: ChatMessage): ReactNode => {
const { styles } = useStyles();
const { t } = useTranslation();
const [model] = useSessionStore((s) => [agentSelectors.currentAgentModel(s)], shallow);
// 1. 只有 ai 的 message
// 2. 且存在 fromModel
// 3. 且当前的 model 和 fromModel 不一致时
if (role === 'assistant' && extra?.fromModel && model !== extra?.fromModel)
// 才需要展示 model tag
return (
<Flexbox className={styles.container}>
<div>
<Tag bordered={false} style={{ borderRadius: 6 }}>
{extra?.fromModel}
</Tag>
</div>
</Flexbox>
);
const plugin = pluginList.find((p) => p.name === function_call?.name);
const funcTag = (
<Tooltip title={function_call?.arguments}>
<Tag bordered={false} className={styles.plugin} style={{ borderRadius: 6 }}>
{plugin?.avatar ? (
<Avatar avatar={plugin?.avatar} size={18} />
) : (
<Icon icon={LucideToyBrick} />
)}
{t(`plugin-${function_call?.name}` as any)}
</Tag>
</Tooltip>
);
const modelTag = (
<div>
<Tag bordered={false} style={{ borderRadius: 6 }}>
{extra?.fromModel}
</Tag>
</div>
);
// 1. 存在 fromModel
// 2. 且当前的 model 和 fromModel 不一致时
const hasModelTag = extra?.fromModel && model !== extra?.fromModel;
const hasFuncTag = !!function_call;
switch (role) {
case 'user':
case 'system': {
return;
}
case 'assistant': {
// 1. 只有 ai 的 message
// 2. 且存在 fromModel
// 3. 且当前的 model 和 fromModel 不一致时
if (!(hasModelTag || hasFuncTag)) return;
return (
<Flexbox className={styles.container}>
{hasFuncTag && funcTag}
{hasModelTag && modelTag}
</Flexbox>
);
}
case 'function': {
return <Flexbox className={styles.container}>{funcTag}</Flexbox>;
}
}
};
export default MessageExtra;

View File

@@ -9,7 +9,7 @@ import { URLS } from './url';
*/
export const fetchChatModel = (
params: Partial<OpenAIStreamPayload>,
signal?: AbortSignal | undefined,
options?: { signal?: AbortSignal | undefined; withPlugin?: boolean },
) => {
const payload = merge(
{
@@ -23,12 +23,12 @@ export const fetchChatModel = (
params,
);
return fetch(URLS.openai, {
return fetch(options?.withPlugin ? URLS.plugins : URLS.openai, {
body: JSON.stringify(payload),
headers: {
'Content-Type': 'application/json',
},
method: 'POST',
signal,
signal: options?.signal,
});
};

View File

@@ -4,4 +4,5 @@ const prefix = isDev ? '-dev' : '';
export const URLS = {
openai: '/api/openai' + prefix,
plugins: '/api/plugins' + prefix,
};

View File

@@ -3,7 +3,8 @@ import { StateCreator } from 'zustand/vanilla';
import { fetchChatModel } from '@/services/chatModel';
import { SessionStore, agentSelectors, chatSelectors, sessionSelectors } from '@/store/session';
import { ChatMessage } from '@/types/chatMessage';
import { FetchSSEOptions, fetchSSE } from '@/utils/fetch';
import { fetchSSE } from '@/utils/fetch';
import { isFunctionMessage } from '@/utils/message';
import { nanoid } from '@/utils/uuid';
import { MessageDispatch, messagesReducer } from './messageReducer';
@@ -38,7 +39,11 @@ export interface ChatAction {
* @param messages - 聊天消息数组
* @param options - 获取 SSE 选项
*/
generateMessage: (messages: ChatMessage[], options: FetchSSEOptions) => Promise<void>;
generateMessage: (
messages: ChatMessage[],
assistantMessageId: string,
withPlugin?: boolean,
) => Promise<{ isFunctionCall: boolean; output: string }>;
/**
* 处理消息编辑
* @param messageId - 消息 ID可选
@@ -100,16 +105,57 @@ export const createChatSlice: StateCreator<
get().dispatchSession({ chats, id: activeId, type: 'updateSessionChat' });
},
generateMessage: async (messages, options) => {
generateMessage: async (messages, assistantId, withPlugin) => {
const { dispatchMessage } = get();
set({ chatLoading: true });
const config = agentSelectors.currentAgentConfigSafe(get());
const fetcher = () =>
fetchChatModel({ messages, model: config.model, ...config.params, plugins: config.plugins });
fetchChatModel(
{ messages, model: config.model, ...config.params, plugins: config.plugins },
{ withPlugin },
);
await fetchSSE(fetcher, options);
let output = '';
let isFunctionCall = false;
await fetchSSE(fetcher, {
onErrorHandle: (error) => {
dispatchMessage({ id: assistantId, key: 'error', type: 'updateMessage', value: error });
},
onMessageHandle: (text) => {
output += text;
dispatchMessage({
id: assistantId,
key: 'content',
type: 'updateMessage',
value: output,
});
// 如果是 function call
if (isFunctionMessage(output)) {
isFunctionCall = true;
// 设为 function
dispatchMessage({
id: assistantId,
key: 'role',
type: 'updateMessage',
value: 'function',
});
}
// 滚动到最后一条消息
const item = document.querySelector('#for-loading');
if (!item) return;
item.scrollIntoView({ behavior: 'smooth' });
},
});
set({ chatLoading: false });
return { isFunctionCall, output };
},
handleMessageEditing: (messageId) => {
@@ -145,29 +191,36 @@ export const createChatSlice: StateCreator<
value: model,
});
let output = '';
// 生成 ai message
await generateMessage(messages, {
onErrorHandle: (error) => {
dispatchMessage({ id: assistantId, key: 'error', type: 'updateMessage', value: error });
},
onMessageHandle: (text) => {
output += text;
const { output, isFunctionCall } = await generateMessage(messages, assistantId);
dispatchMessage({
id: assistantId,
key: 'content',
type: 'updateMessage',
value: output,
});
// 如果是 function则发送函数调用方法
if (isFunctionCall) {
const { function_call } = JSON.parse(output);
// 滚动到最后一条消息
const item = document.querySelector('#for-loading');
if (!item) return;
dispatchMessage({
id: assistantId,
key: 'function_call',
type: 'updateMessage',
value: function_call,
});
item.scrollIntoView({ behavior: 'smooth' });
},
});
await generateMessage(
[
...messages,
{ content: '', function_call, id: assistantId, role: 'assistant' } as ChatMessage,
],
assistantId,
true,
);
dispatchMessage({
id: assistantId,
key: 'role',
type: 'updateMessage',
value: 'assistant',
});
}
},
resendMessage: async (messageId) => {
@@ -182,7 +235,7 @@ export const createChatSlice: StateCreator<
const histories = chats
.slice(0, currentIndex + 1)
// 如果点击重新发送的 message 其 role 是 assistant那么需要移除
// 如果点击重新发送的 message 其 role 是 assistant 或者 function,那么需要移除
// 如果点击重新发送的 message 其 role 是 user则不需要移除
.filter((c) => !(c.role === 'assistant' && c.id === messageId));

View File

@@ -1,3 +1,4 @@
import pluginList from '@/plugins';
import { ChatMessage } from '@/types/chatMessage';
import { LobeAgentSession } from '@/types/session';
@@ -5,6 +6,35 @@ export const organizeChats = (
session: LobeAgentSession,
avatar: { assistant: string; user: string },
) => {
const getMeta = (message: ChatMessage) => {
switch (message.role) {
case 'user': {
return {
avatar: avatar.user,
};
}
case 'system': {
return message.meta;
}
case 'assistant': {
return {
avatar: avatar.assistant,
title: session.meta.title,
};
}
case 'function': {
const plugin = pluginList.find((p) => p.name === message.function_call?.name);
return {
avatar: plugin?.avatar || '🧩',
title: plugin?.name || 'plugin-unknown',
};
}
}
};
const basic = Object.values<ChatMessage>(session.chats)
// 首先按照时间顺序排序,越早的在越前面
.sort((pre, next) => pre.createAt - next.createAt)
@@ -14,17 +44,7 @@ export const organizeChats = (
.map((m) => {
return {
...m,
meta:
m.role === 'assistant'
? {
avatar: avatar.assistant,
title: session.meta.title,
}
: m.role === 'user'
? {
avatar: avatar.user,
}
: m.meta,
meta: getMeta(m),
};
});

View File

@@ -22,7 +22,6 @@ export interface ChatMessage extends BaseDataModel {
*/
content: string;
error?: any;
// 扩展字段
extra?: {
fromModel?: string;
@@ -33,6 +32,9 @@ export interface ChatMessage extends BaseDataModel {
};
} & Record<string, any>;
function_call?: { arguments?: string; name: string };
name?: string;
parentId?: string;
// 引用
quotaId?: string;

View File

@@ -39,7 +39,7 @@ export interface LLMParams {
top_p?: number;
}
export type LLMRoleType = 'user' | 'system' | 'assistant';
export type LLMRoleType = 'user' | 'system' | 'assistant' | 'function';
export interface LLMMessage {
content: string;

View File

@@ -91,7 +91,7 @@ interface FetchAITaskResultParams<T> {
}
export const fetchAIFactory =
<T>(fetcher: (params: T, signal?: AbortSignal) => Promise<Response>) =>
<T>(fetcher: (params: T, options: { signal?: AbortSignal }) => Promise<Response>) =>
async ({
params,
onMessageHandle,
@@ -117,7 +117,7 @@ export const fetchAIFactory =
onLoadingChange?.(true);
const data = await fetchSSE(() => fetcher(params, abortController?.signal), {
const data = await fetchSSE(() => fetcher(params, { signal: abortController?.signal }), {
onErrorHandle: (error) => {
errorHandle(new Error(error.message));
},

5
src/utils/message.ts Normal file
View File

@@ -0,0 +1,5 @@
import { FUNCTION_MESSAGE_FLAG } from '@/const/message';
/**
 * Whether a raw assistant message is actually a serialized function call,
 * i.e. its content starts with {@link FUNCTION_MESSAGE_FLAG}.
 */
export const isFunctionMessage = (content: string) =>
  content.startsWith(FUNCTION_MESSAGE_FLAG);