♻️ refactor: refactor for session server mode (#2163)

* ♻️ refactor: refactor for session server mode

*  test: fix test

*  test: add tests

* 🚨 chore: fix lint
This commit is contained in:
Arvin Xu
2024-04-24 01:01:51 +08:00
committed by GitHub
parent 2d52303888
commit e012597226
36 changed files with 544 additions and 209 deletions

2
.npmrc
View File

@@ -12,3 +12,5 @@ public-hoist-pattern[]=*prettier*
public-hoist-pattern[]=*remark*
public-hoist-pattern[]=*semantic-release*
public-hoist-pattern[]=*stylelint*
public-hoist-pattern[]=@auth/core

View File

@@ -173,7 +173,7 @@
"@next/eslint-plugin-next": "^14.2.2",
"@peculiar/webcrypto": "^1.4.6",
"@testing-library/jest-dom": "^6.4.2",
"@testing-library/react": "^15.0.2",
"@testing-library/react": "^15.0.4",
"@types/chroma-js": "^2.4.4",
"@types/debug": "^4.1.12",
"@types/diff": "^5.2.0",

View File

@@ -27,7 +27,7 @@ const Actions = memo<ActionsProps>(
({ id, openRenameModal, openConfigModal, onOpenChange, isCustomGroup, isPinned }) => {
const { t } = useTranslation('chat');
const { styles } = useStyles();
const { modal } = App.useApp();
const { modal, message } = App.useApp();
const [createSession, removeSessionGroup] = useSessionStore((s) => [
s.createSession,
@@ -48,9 +48,15 @@ const Actions = memo<ActionsProps>(
icon: <Icon icon={Plus} />,
key: 'newAgent',
label: t('newAgent'),
onClick: ({ domEvent }) => {
onClick: async ({ domEvent }) => {
domEvent.stopPropagation();
createSession({ group: id, pinned: isPinned });
const key = 'createNewAgentInGroup';
message.loading({ content: t('sessionGroup.creatingAgent'), duration: 0, key });
await createSession({ group: id, pinned: isPinned });
message.destroy(key);
message.success({ content: t('sessionGroup.createAgentSuccess') });
},
};
@@ -83,9 +89,9 @@ const Actions = memo<ActionsProps>(
modal.confirm({
centered: true,
okButtonProps: { danger: true },
onOk: () => {
onOk: async () => {
if (!id) return;
removeSessionGroup(id);
await removeSessionGroup(id);
},
rootClassName: styles.modalRoot,
title: t('sessionGroup.confirmRemoveGroupAlert'),

View File

@@ -1,10 +1,12 @@
import { CollapseProps } from 'antd';
import isEqual from 'fast-deep-equal';
import { memo, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalStore } from '@/store/global';
import { preferenceSelectors } from '@/store/global/selectors';
import { useSessionStore } from '@/store/session';
import { sessionSelectors } from '@/store/session/selectors';
import { SessionDefaultGroup } from '@/types/session';
import Actions from '../SessionListContent/CollapseGroup/Actions';
@@ -22,11 +24,11 @@ const SessionListContent = memo(() => {
const [configGroupModalOpen, setConfigGroupModalOpen] = useState(false);
const [useFetchSessions] = useSessionStore((s) => [s.useFetchSessions]);
const { data } = useFetchSessions();
useFetchSessions();
const pinnedSessions = data?.pinned;
const defaultSessions = data?.default;
const customSessionGroups = data?.customGroup;
const defaultSessions = useSessionStore(sessionSelectors.defaultSessions, isEqual);
const customSessionGroups = useSessionStore(sessionSelectors.customSessionGroups, isEqual);
const pinnedSessions = useSessionStore(sessionSelectors.pinnedSessions, isEqual);
const [sessionGroupKeys, updatePreference] = useGlobalStore((s) => [
preferenceSelectors.sessionGroupKeys(s),

View File

@@ -12,18 +12,13 @@ const AddButton = memo<{ groupId?: string }>(({ groupId }) => {
const { t } = useTranslation('chat');
const createSession = useSessionStore((s) => s.createSession);
const { mutate, isValidating } = useActionSWR('session.createSession', (groupId) =>
createSession({ group: groupId }),
);
const { mutate, isValidating } = useActionSWR(['session.createSession', groupId], () => {
return createSession({ group: groupId });
});
return (
<Flexbox style={{ margin: '12px 16px' }}>
<Button
block
icon={<Icon icon={Plus} />}
loading={isValidating}
onClick={() => mutate(groupId)}
>
<Button block icon={<Icon icon={Plus} />} loading={isValidating} onClick={() => mutate()}>
{t('newAgent')}
</Button>
</Flexbox>

View File

@@ -1,5 +1,5 @@
import { ActionIcon, EditableText, SortableList } from '@lobehub/ui';
import { App, Popconfirm } from 'antd';
import { App } from 'antd';
import { createStyles } from 'antd-style';
import { PencilLine, Trash } from 'lucide-react';
import { memo, useState } from 'react';
@@ -25,7 +25,7 @@ const useStyles = createStyles(({ css }) => ({
const GroupItem = memo<SessionGroupItem>(({ id, name }) => {
const { t } = useTranslation('chat');
const { styles } = useStyles();
const { message } = App.useApp();
const { message, modal } = App.useApp();
const [editing, setEditing] = useState(false);
const [updateSessionGroupName, removeSessionGroup] = useSessionStore((s) => [
@@ -40,29 +40,34 @@ const GroupItem = memo<SessionGroupItem>(({ id, name }) => {
<>
<span className={styles.title}>{name}</span>
<ActionIcon icon={PencilLine} onClick={() => setEditing(true)} size={'small'} />
<Popconfirm
arrow={false}
okButtonProps={{
danger: true,
type: 'primary',
<ActionIcon
icon={Trash}
onClick={() => {
modal.confirm({
centered: true,
okButtonProps: {
danger: true,
type: 'primary',
},
onOk: async () => {
await removeSessionGroup(id);
},
title: t('sessionGroup.confirmRemoveGroupAlert'),
});
}}
onConfirm={() => {
removeSessionGroup(id);
}}
title={t('sessionGroup.confirmRemoveGroupAlert')}
>
<ActionIcon icon={Trash} size={'small'} />
</Popconfirm>
size={'small'}
/>
</>
) : (
<EditableText
editing={editing}
onChangeEnd={(input) => {
onChangeEnd={async (input) => {
if (name !== input) {
if (!input) return;
if (input.length === 0 || input.length > 20)
return message.warning(t('sessionGroup.tooLong'));
updateSessionGroupName(id, input);
await updateSessionGroupName(id, input);
message.success(t('sessionGroup.renameSuccess'));
}
setEditing(false);

View File

@@ -3,7 +3,7 @@ import { Button } from 'antd';
import { createStyles } from 'antd-style';
import isEqual from 'fast-deep-equal';
import { Plus } from 'lucide-react';
import { memo } from 'react';
import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
@@ -35,6 +35,7 @@ const ConfigGroupModal = memo<ModalProps>(({ open, onCancel }) => {
s.addSessionGroup,
s.updateSessionGroupSort,
]);
const [loading, setLoading] = useState(false);
return (
<Modal
@@ -67,7 +68,12 @@ const ConfigGroupModal = memo<ModalProps>(({ open, onCancel }) => {
<Button
block
icon={<Icon icon={Plus} />}
onClick={() => addSessionGroup(t('sessionGroup.newGroup'))}
loading={loading}
onClick={async () => {
setLoading(true);
await addSessionGroup(t('sessionGroup.newGroup'));
setLoading(false);
}}
>
{t('sessionGroup.createGroup')}
</Button>

View File

@@ -22,21 +22,29 @@ const CreateGroupModal = memo<CreateGroupModalProps>(
s.addSessionGroup,
]);
const [input, setInput] = useState('');
const [loading, setLoading] = useState(false);
return (
<div onClick={(e) => e.stopPropagation()}>
<Modal
allowFullscreen
onCancel={onCancel}
destroyOnClose
okButtonProps={{ loading }}
onCancel={(e) => {
setInput('');
onCancel?.(e);
}}
onOk={async (e: MouseEvent<HTMLButtonElement>) => {
if (!input) return;
if (input.length === 0 || input.length > 20)
return message.warning(t('sessionGroup.tooLong'));
setLoading(true);
const groupId = await addCustomGroup(input);
await updateSessionGroup(id, groupId);
toggleExpandSessionGroup(groupId, true);
setLoading(false);
message.success(t('sessionGroup.createSuccess'));
onCancel?.(e);

View File

@@ -18,18 +18,27 @@ const RenameGroupModal = memo<RenameGroupModalProps>(({ id, open, onCancel }) =>
const group = useSessionStore((s) => sessionGroupSelectors.getGroupById(id)(s), isEqual);
const [input, setInput] = useState<string>();
const [loading, setLoading] = useState(false);
const { message } = App.useApp();
return (
<Modal
allowFullscreen
onCancel={onCancel}
onOk={(e) => {
destroyOnClose
okButtonProps={{ loading }}
onCancel={(e) => {
setInput(group?.name);
onCancel?.(e);
}}
onOk={async (e) => {
if (!input) return;
if (input.length === 0 || input.length > 20)
return message.warning(t('sessionGroup.tooLong'));
updateSessionGroupName(id, input);
setLoading(true);
await updateSessionGroupName(id, input);
message.success(t('sessionGroup.renameSuccess'));
setLoading(false);
onCancel?.(e);
}}
open={open}

View File

@@ -8,8 +8,8 @@ import { sessionService } from '@/services/session';
const checkHasConversation = async () => {
const hasMessages = await messageService.hasMessages();
const hasAgents = await sessionService.countSessions();
return hasMessages || hasAgents === 0;
const hasAgents = await sessionService.hasSessions();
return hasMessages || hasAgents;
};
const Redirect = memo(() => {

View File

@@ -16,7 +16,7 @@ const Footer = memo(() => {
return (
<Flexbox align={'center'} horizontal justify={'space-between'} style={{ padding: 16 }}>
<span style={{ color: theme.colorTextDescription }}>
©{new Date().getFullYear()} LobeHub
© 2023 - {new Date().getFullYear()} LobeHub, LLC
</span>
<Flexbox horizontal>
<ActionIcon

View File

@@ -3,12 +3,12 @@
import { Icon } from '@lobehub/ui';
import { Button } from 'antd';
import { SendHorizonal } from 'lucide-react';
import Link from 'next/link';
import { useRouter } from 'next/navigation';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import DataImporter from '@/features/DataImporter';
import { useGlobalStore } from '@/store/global';
import Hero from './Hero';
@@ -32,15 +32,11 @@ const Banner = memo<{ mobile?: boolean }>(({ mobile }) => {
justify={'center'}
width={'100%'}
>
<DataImporter
onFinishImport={() => {
router.push('/chat');
}}
>
<Button block={mobile} size={'large'}>
{t('button.import')}
<Link href={'/market'}>
<Button block={mobile} size={'large'} type={'default'}>
{t('button.market')}
</Button>
</DataImporter>
</Link>
<Button
block={mobile}
onClick={() => (isMobile ? router.push('/chat') : switchBackToChat())}

View File

@@ -179,8 +179,8 @@ describe('SessionModel', () => {
await SessionModel.create('agent', sessionData);
const sessionsWithGroups = await SessionModel.queryWithGroups();
expect(sessionsWithGroups.all).toHaveLength(1);
expect(sessionsWithGroups.all[0]).toEqual(expect.objectContaining(sessionData));
expect(sessionsWithGroups.sessions).toHaveLength(1);
expect(sessionsWithGroups.sessions[0]).toEqual(expect.objectContaining(sessionData));
});
});

View File

@@ -43,21 +43,11 @@ class _SessionModel extends BaseModel {
}
async queryWithGroups(): Promise<ChatSessionList> {
const groups = await SessionGroupModel.query();
const customGroups = await this.queryByGroupIds(groups.map((item) => item.id));
const defaultItems = await this.querySessionsByGroupId(SessionDefaultGroup.Default);
const pinnedItems = await this.getPinnedSessions();
const sessionGroups = await SessionGroupModel.query();
const all = await this.query();
return {
all,
customGroup: groups.map((group) => ({
...group,
children: customGroups[group.id],
})),
default: defaultItems,
pinned: pinnedItems,
};
const sessions = await this.query();
return { sessionGroups, sessions };
}
/**

View File

@@ -50,13 +50,17 @@ export default {
sessionGroup: {
config: '分组管理',
confirmRemoveGroupAlert: '即将删除该分组,删除后该分组的助手将移动到默认列表,请确认你的操作',
createAgentSuccess: '助手创建成功',
createGroup: '添加新分组',
createSuccess: '创建成功',
createSuccess: '分组创建成功',
creatingAgent: '助手创建中...',
inputPlaceholder: '请输入分组名称...',
moveGroup: '移动到分组',
newGroup: '新分组',
rename: '重命名分组',
renameSuccess: '重命名成功',
sortSuccess: '重新排序成功',
sorting: '分组排序更新中...',
tooLong: '分组名称长度需在 1-20 之内',
},
shareModal: {
@@ -126,6 +130,6 @@ export default {
dragDesc: '拖拽文件到这里,支持上传多个图片。按住 Shift 直接发送图片',
dragFileDesc: '拖拽图片和文件到这里,支持上传多个图片和文件。按住 Shift 直接发送图片或文件',
dragFileTitle: '上传文件',
dragTitle: '上传图片'
dragTitle: '上传图片',
},
};

View File

@@ -1,6 +1,7 @@
export default {
button: {
import: '导入配置',
market: '逛逛市场',
start: '立即开始',
},
header: '欢迎使用',

View File

@@ -65,31 +65,15 @@ class ConfigService {
case 'all': {
await this.importSettings(config.state.settings);
const sessionGroups = await this.importSessionGroups(config.state.sessionGroups);
const [sessions, messages, topics] = await Promise.all([
this.importSessions(config.state.sessions),
this.importMessages(config.state.messages),
this.importTopics(config.state.topics),
]);
return {
messages: this.mapImportResult(messages),
sessionGroups: this.mapImportResult(sessionGroups),
sessions: this.mapImportResult(sessions),
topics: this.mapImportResult(topics),
};
}
// all and sessions have the same data process, so we can fall through
// eslint-disable-next-line no-fallthrough
case 'sessions': {
const sessionGroups = await this.importSessionGroups(config.state.sessionGroups);
const [sessions, messages, topics] = await Promise.all([
this.importSessions(config.state.sessions),
this.importMessages(config.state.messages),
this.importTopics(config.state.topics),
]);
const sessions = await this.importSessions(config.state.sessions);
const topics = await this.importTopics(config.state.topics);
const messages = await this.importMessages(config.state.messages);
return {
messages: this.mapImportResult(messages),

View File

@@ -62,4 +62,14 @@ export class ClientService implements IMessageService {
async removeAllMessages() {
return MessageModel.clearTable();
}
async hasMessages() {
const number = await this.countMessages();
return number > 0;
}
async messageCountToCheckTrace() {
const number = await this.countMessages();
return number >= 4;
}
}

View File

@@ -9,16 +9,4 @@ import { ClientService } from './client';
export type { CreateMessageParams } from './type';
class MessageService extends ClientService {
async hasMessages() {
const number = await this.countMessages();
return number > 0;
}
async messageCountToCheckTrace() {
const number = await this.countMessages();
return number >= 4;
}
}
export const messageService = new MessageService();
export const messageService = new ClientService();

View File

@@ -30,4 +30,7 @@ export interface IMessageService {
removeMessage(id: string): Promise<any>;
removeMessages(assistantId: string, topicId?: string): Promise<any>;
removeAllMessages(): Promise<any>;
hasMessages(): Promise<boolean>;
messageCountToCheckTrace(): Promise<boolean>;
}

View File

@@ -64,6 +64,9 @@ export class ClientService implements ISessionService {
async countSessions() {
return SessionModel.count();
}
async hasSessions() {
return (await this.countSessions()) === 0;
}
async searchSessions(keyword: string) {
return SessionModel.queryByKeyword(keyword);

View File

@@ -2,11 +2,13 @@
import { DeepPartial } from 'utility-types';
import { LobeAgentConfig } from '@/types/agent';
import { BatchTaskResult } from '@/types/service';
import {
ChatSessionList,
LobeAgentSession,
LobeSessionType,
LobeSessions,
SessionGroupId,
SessionGroupItem,
SessionGroups,
} from '@/types/session';
@@ -19,9 +21,13 @@ export interface ISessionService {
getGroupedSessions(): Promise<ChatSessionList>;
getSessionsByType(type: 'agent' | 'group' | 'all'): Promise<LobeSessions>;
countSessions(): Promise<number>;
hasSessions(): Promise<boolean>;
searchSessions(keyword: string): Promise<LobeSessions>;
updateSession(id: string, data: Partial<Pick<LobeAgentSession, 'group' | 'meta'>>): Promise<any>;
updateSession(
id: string,
data: Partial<{ group?: SessionGroupId; pinned?: boolean }>,
): Promise<any>;
updateSessionConfig(id: string, config: DeepPartial<LobeAgentConfig>): Promise<any>;
removeSession(id: string): Promise<any>;
@@ -32,7 +38,7 @@ export interface ISessionService {
// ************************************** //
createSessionGroup(name: string, sort?: number): Promise<string>;
batchCreateSessionGroups(groups: SessionGroups): Promise<any>;
batchCreateSessionGroups(groups: SessionGroups): Promise<BatchTaskResult>;
getSessionGroups(): Promise<SessionGroupItem[]>;

View File

@@ -5,7 +5,7 @@ import { CreateTopicParams, ITopicService, QueryTopicParams } from './type';
export class ClientService implements ITopicService {
async createTopic(params: CreateTopicParams): Promise<string> {
const item = await TopicModel.create(params);
const item = await TopicModel.create(params as any);
if (!item) {
throw new Error('topic create Error');

View File

@@ -1,10 +1,11 @@
/* eslint-disable typescript-sort-keys/interface */
import { BatchTaskResult } from '@/types/service';
import { ChatTopic } from '@/types/topic';
export interface CreateTopicParams {
favorite?: boolean;
messages?: string[];
sessionId: string;
sessionId?: string | null;
title: string;
}
@@ -16,7 +17,7 @@ export interface QueryTopicParams {
export interface ITopicService {
createTopic(params: CreateTopicParams): Promise<string>;
batchCreateTopics(importTopics: ChatTopic[]): Promise<any>;
batchCreateTopics(importTopics: ChatTopic[]): Promise<BatchTaskResult>;
cloneTopic(id: string, newTitle?: string): Promise<string>;
getTopics(params: QueryTopicParams): Promise<ChatTopic[]>;

View File

@@ -1,5 +1,4 @@
import { t } from 'i18next';
import { produce } from 'immer';
import useSWR, { SWRResponse, mutate } from 'swr';
import { DeepPartial } from 'utility-types';
import { StateCreator } from 'zustand/vanilla';
@@ -11,12 +10,20 @@ import { sessionService } from '@/services/session';
import { useGlobalStore } from '@/store/global';
import { settingsSelectors } from '@/store/global/selectors';
import { SessionStore } from '@/store/session';
import { ChatSessionList, LobeAgentSession, LobeSessionType, LobeSessions } from '@/types/session';
import {
ChatSessionList,
LobeAgentSession,
LobeSessionGroups,
LobeSessionType,
LobeSessions,
SessionGroupId,
} from '@/types/session';
import { merge } from '@/utils/merge';
import { setNamespace } from '@/utils/storeDebug';
import { agentSelectors } from '../agent/selectors';
import { initLobeSession } from './initialState';
import { SessionDispatch, sessionsReducer } from './reducers';
import { sessionSelectors } from './selectors';
const n = setNamespace('session');
@@ -24,6 +31,7 @@ const n = setNamespace('session');
const FETCH_SESSIONS_KEY = 'fetchSessions';
const SEARCH_SESSIONS_KEY = 'searchSessions';
/* eslint-disable typescript-sort-keys/interface */
export interface SessionAction {
/**
* active the session
@@ -44,6 +52,8 @@ export interface SessionAction {
isSwitchSession?: boolean,
) => Promise<string>;
duplicateSession: (id: string) => Promise<void>;
updateSessionGroupId: (sessionId: string, groupId: string) => Promise<void>;
/**
* Pins or unpins a session.
*/
@@ -52,17 +62,26 @@ export interface SessionAction {
* re-fetch the data
*/
refreshSessions: (params?: SWRRefreshParams<ChatSessionList>) => Promise<void>;
/**
* remove session
* @param id - sessionId
*/
removeSession: (id: string) => void;
/**
* A custom hook that uses SWR to fetch sessions data.
*/
removeSession: (id: string) => Promise<void>;
useFetchSessions: () => SWRResponse<ChatSessionList>;
useSearchSessions: (keyword?: string) => SWRResponse<any>;
internal_dispatchSessions: (payload: SessionDispatch) => void;
internal_updateSession: (
id: string,
data: Partial<{ group?: SessionGroupId; meta?: any; pinned?: boolean }>,
) => Promise<void>;
internal_processSessions: (
sessions: LobeSessions,
customGroups: LobeSessionGroups,
actions?: string,
) => void;
/* eslint-enable */
}
export const createSessionSlice: StateCreator<
@@ -101,7 +120,6 @@ export const createSessionSlice: StateCreator<
return id;
},
duplicateSession: async (id) => {
const { activeSession, refreshSessions } = get();
const session = sessionSelectors.getSessionById(id)(get());
@@ -135,63 +153,12 @@ export const createSessionSlice: StateCreator<
activeSession(newId);
},
pinSession: async (sessionId, pinned) => {
await get().refreshSessions({
action: async () => {
await sessionService.updateSession(sessionId, { pinned });
},
// 乐观更新
optimisticData: produce((draft) => {
if (!draft) return;
const session = draft.all.find((i) => i.id === sessionId);
if (!session) return;
session.pinned = pinned;
if (pinned) {
draft.pinned.unshift(session);
if (session.group === 'default') {
const index = draft.default.findIndex((i) => i.id === sessionId);
draft.default.splice(index, 1);
} else {
const customGroup = draft.customGroup.find((group) => group.id === session.group);
if (customGroup) {
const index = customGroup.children.findIndex((i) => i.id === sessionId);
customGroup.children.splice(index, 1);
}
}
} else {
const index = draft.pinned.findIndex((i) => i.id === sessionId);
if (index !== -1) {
draft.pinned.splice(index, 1);
}
if (session.group === 'default') {
draft.default.push(session);
} else {
const customGroup = draft.customGroup.find((group) => group.id === session.group);
if (customGroup) {
customGroup.children.push(session);
}
}
}
}),
});
pinSession: async (id, pinned) => {
await get().internal_updateSession(id, { pinned });
},
refreshSessions: async (params) => {
if (params) {
// @ts-ignore
await mutate(FETCH_SESSIONS_KEY, params.action, {
optimisticData: params.optimisticData,
// we won't need to make the action's data go into cache ,or the display will be
// old -> optimistic -> undefined -> new
populateCache: false,
});
} else await mutate(FETCH_SESSIONS_KEY);
refreshSessions: async () => {
await mutate(FETCH_SESSIONS_KEY);
},
removeSession: async (sessionId) => {
@@ -204,8 +171,10 @@ export const createSessionSlice: StateCreator<
}
},
// TODO: 这里的逻辑需要优化,后续不应该是直接请求一个大的 sessions 数据
// 最好拆成一个 all 请求,然后在前端完成 groupBy 的分组逻辑
updateSessionGroupId: async (sessionId, group) => {
await get().internal_updateSession(sessionId, { group });
},
useFetchSessions: () =>
useClientDataSWR<ChatSessionList>(FETCH_SESSIONS_KEY, sessionService.getGroupedSessions, {
onSuccess: (data) => {
@@ -217,20 +186,14 @@ export const createSessionSlice: StateCreator<
// TODO后续的根本解法应该是解除 inbox 和 session 的数据耦合
// 避免互相依赖的情况出现
set(
{
customSessionGroups: data.customGroup,
defaultSessions: data.default,
isSessionsFirstFetchFinished: true,
pinnedSessions: data.pinned,
sessions: data.all,
},
false,
n('useFetchSessions/onSuccess', data),
get().internal_processSessions(
data.sessions,
data.sessionGroups,
n('useFetchSessions/updateData') as any,
);
set({ isSessionsFirstFetchFinished: true }, false, n('useFetchSessions/onSuccess', data));
},
}),
useSearchSessions: (keyword) =>
useSWR<LobeSessions>(
[SEARCH_SESSIONS_KEY, keyword],
@@ -241,4 +204,39 @@ export const createSessionSlice: StateCreator<
},
{ revalidateOnFocus: false, revalidateOnMount: false },
),
/* eslint-disable sort-keys-fix/sort-keys-fix */
internal_dispatchSessions: (payload) => {
const nextSessions = sessionsReducer(get().sessions, payload);
get().internal_processSessions(nextSessions, get().sessionGroups);
},
internal_updateSession: async (id, data) => {
get().internal_dispatchSessions({ type: 'updateSession', id, value: data });
await sessionService.updateSession(id, data);
await get().refreshSessions();
},
internal_processSessions: (sessions, sessionGroups) => {
const customGroups = sessionGroups.map((item) => ({
...item,
children: sessions.filter((i) => i.group === item.id && !i.pinned),
}));
const defaultGroup = sessions.filter(
(item) => (!item.group || item.group === 'default') && !item.pinned,
);
const pinnedGroup = sessions.filter((item) => item.pinned);
set(
{
customSessionGroups: customGroups,
defaultSessions: defaultGroup,
pinnedSessions: pinnedGroup,
sessionGroups,
sessions,
},
false,
n('processSessions'),
);
},
});

View File

@@ -1,6 +1,11 @@
import { DEFAULT_AGENT_META } from '@/const/meta';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
import { CustomSessionGroup, LobeAgentSession, LobeSessionType } from '@/types/session';
import {
CustomSessionGroup,
LobeAgentSession,
LobeSessionGroups,
LobeSessionType,
} from '@/types/session';
export const initLobeSession: LobeAgentSession = {
config: DEFAULT_AGENT_CONFIG,
@@ -24,6 +29,7 @@ export interface SessionState {
isSessionsFirstFetchFinished: boolean;
pinnedSessions: LobeAgentSession[];
searchKeywords: string;
sessionGroups: LobeSessionGroups;
sessionSearchKeywords?: string;
/**
* it means defaultSessions
@@ -40,5 +46,6 @@ export const initialSessionState: SessionState = {
isSessionsFirstFetchFinished: false,
pinnedSessions: [],
searchKeywords: '',
sessionGroups: [],
sessions: [],
};

View File

@@ -0,0 +1,79 @@
import { nanoid } from 'nanoid';
import { describe, expect, it, vi } from 'vitest';
import { LobeAgentConfig } from '@/types/agent';
import { LobeAgentSession, LobeSessions } from '@/types/session';
import { SessionDispatch, sessionsReducer } from './reducers';
describe('sessionsReducer', () => {
const mockSession = {
id: nanoid(),
config: {
model: 'gpt-3.5-turbo',
} as any,
meta: {
title: 'Test Agent',
description: 'A test agent',
avatar: '',
},
} as any;
const initialState: LobeSessions = [];
it('should add a new session', () => {
const addAction: SessionDispatch = {
session: mockSession,
type: 'addSession',
};
const newState = sessionsReducer(initialState, addAction);
expect(newState).toHaveLength(1);
expect(newState[0]).toMatchObject({
...mockSession,
createdAt: expect.any(Date),
updatedAt: expect.any(Date),
});
});
it('should remove an existing session', () => {
const state: LobeSessions = [mockSession];
const removeAction: SessionDispatch = {
id: mockSession.id,
type: 'removeSession',
};
const newState = sessionsReducer(state, removeAction);
expect(newState).toHaveLength(0);
});
it('should update an existing session', () => {
const state: LobeSessions = [mockSession];
const updateAction: SessionDispatch = {
id: mockSession.id,
type: 'updateSession',
value: { group: 'abc' },
};
const newState = sessionsReducer(state, updateAction);
expect(newState).toHaveLength(1);
expect(newState[0]).toMatchObject({
...mockSession,
group: 'abc',
updatedAt: expect.any(Date),
});
});
it('should return the same state for unknown action', () => {
const state: LobeSessions = [mockSession];
// @ts-ignore
const unknownAction: SessionDispatch = { type: 'unknown' };
const newState = sessionsReducer(state, unknownAction);
expect(newState).toEqual(state);
});
});

View File

@@ -0,0 +1,61 @@
import { produce } from 'immer';
import { LobeAgentSession, LobeSessions } from '@/types/session';
interface AddSession {
session: LobeAgentSession;
type: 'addSession';
}
interface RemoveSession {
id: string;
type: 'removeSession';
}
interface UpdateSession {
id: string;
type: 'updateSession';
value: Partial<LobeAgentSession>;
}
export type SessionDispatch = AddSession | RemoveSession | UpdateSession;
export const sessionsReducer = (state: LobeSessions, payload: SessionDispatch): LobeSessions => {
switch (payload.type) {
case 'addSession': {
return produce(state, (draft) => {
const { session } = payload;
if (!session) return;
// TODO: 后续将 Date 类型做个迁移,就可以移除这里的 ignore 了
// @ts-ignore
draft.unshift({ ...session, createdAt: new Date(), updatedAt: new Date() });
});
}
case 'removeSession': {
return produce(state, (draftState) => {
const index = draftState.findIndex((item) => item.id === payload.id);
if (index !== -1) {
draftState.splice(index, 1);
}
});
}
case 'updateSession': {
return produce(state, (draftState) => {
const { value, id } = payload;
const index = draftState.findIndex((item) => item.id === id);
if (index !== -1) {
// @ts-ignore
draftState[index] = { ...draftState[index], ...value, updatedAt: new Date() };
}
});
}
default: {
return produce(state, () => {});
}
}
};

View File

@@ -8,6 +8,15 @@ afterEach(() => {
vi.restoreAllMocks();
});
vi.mock('@/components/AntdStaticMethods', () => ({
message: {
loading: vi.fn(),
success: vi.fn(),
error: vi.fn(),
destroy: vi.fn(),
},
}));
describe('createSessionGroupSlice', () => {
describe('addSessionGroup', () => {
it('should add a session group and refresh sessions', async () => {

View File

@@ -1,17 +1,23 @@
import { t } from 'i18next';
import { StateCreator } from 'zustand/vanilla';
import { message } from '@/components/AntdStaticMethods';
import { sessionService } from '@/services/session';
import { SessionStore } from '@/store/session';
import { SessionGroupItem } from '@/types/session';
import { SessionGroupsDispatch, sessionGroupsReducer } from './reducer';
/* eslint-disable typescript-sort-keys/interface */
export interface SessionGroupAction {
addSessionGroup: (name: string) => Promise<string>;
clearSessionGroups: () => Promise<void>;
removeSessionGroup: (id: string) => Promise<void>;
updateSessionGroupId: (sessionId: string, groupId: string) => Promise<void>;
updateSessionGroupName: (id: string, name: string) => Promise<void>;
updateSessionGroupSort: (items: SessionGroupItem[]) => Promise<void>;
internal_dispatchSessionGroups: (payload: SessionGroupsDispatch) => void;
}
/* eslint-enable */
export const createSessionGroupSlice: StateCreator<
SessionStore,
@@ -36,11 +42,6 @@ export const createSessionGroupSlice: StateCreator<
await sessionService.removeSessionGroup(id);
await get().refreshSessions();
},
updateSessionGroupId: async (sessionId, group) => {
await sessionService.updateSession(sessionId, { group });
await get().refreshSessions();
},
updateSessionGroupName: async (id, name) => {
await sessionService.updateSessionGroup(id, { name });
@@ -48,7 +49,25 @@ export const createSessionGroupSlice: StateCreator<
},
updateSessionGroupSort: async (items) => {
const sortMap = items.map((item, index) => ({ id: item.id, sort: index }));
get().internal_dispatchSessionGroups({ sortMap, type: 'updateSessionGroupOrder' });
message.loading({
content: t('sessionGroup.sorting', { ns: 'chat' }),
duration: 0,
key: 'updateSessionGroupSort',
});
await sessionService.updateSessionGroupOrder(sortMap);
message.destroy('updateSessionGroupSort');
message.success(t('sessionGroup.sortSuccess', { ns: 'chat' }));
await get().refreshSessions();
},
/* eslint-disable sort-keys-fix/sort-keys-fix */
internal_dispatchSessionGroups: (payload) => {
const nextSessionGroups = sessionGroupsReducer(get().sessionGroups, payload);
get().internal_processSessions(get().sessions, nextSessionGroups, 'updateSessionGroups');
},
});

View File

@@ -0,0 +1,86 @@
import { nanoid } from 'nanoid';
import { describe, expect, it } from 'vitest';
import { SessionGroupItem } from '@/types/session';
import { sessionGroupsReducer } from './reducer';
describe('sessionGroupsReducer', () => {
const initialState: SessionGroupItem[] = [
{
id: nanoid(),
name: 'Group 1',
createdAt: Date.now(),
updatedAt: Date.now(),
},
{
id: nanoid(),
name: 'Group 2',
createdAt: Date.now(),
updatedAt: Date.now(),
sort: 1,
},
];
it('should add a new session group item', () => {
const newItem: SessionGroupItem = {
id: nanoid(),
name: 'New Group',
createdAt: Date.now(),
updatedAt: Date.now(),
};
const result = sessionGroupsReducer(initialState, {
type: 'addSessionGroupItem',
item: newItem,
});
expect(result).toHaveLength(3);
expect(result).toContainEqual(newItem);
});
it('should delete a session group item', () => {
const itemToDelete = initialState[0].id;
const result = sessionGroupsReducer(initialState, {
type: 'deleteSessionGroupItem',
id: itemToDelete,
});
expect(result).toHaveLength(1);
expect(result).not.toContainEqual(expect.objectContaining({ id: itemToDelete }));
});
it('should update a session group item', () => {
const itemToUpdate = initialState[0].id;
const updatedItem = { name: 'Updated Group' };
const result = sessionGroupsReducer(initialState, {
type: 'updateSessionGroupItem',
id: itemToUpdate,
item: updatedItem,
});
expect(result).toHaveLength(2);
expect(result).toContainEqual(expect.objectContaining({ id: itemToUpdate, ...updatedItem }));
});
it('should update session group order', () => {
const sortMap = [
{ id: initialState[1].id, sort: 0 },
{ id: initialState[0].id, sort: 1 },
];
const result = sessionGroupsReducer(initialState, { type: 'updateSessionGroupOrder', sortMap });
expect(result).toHaveLength(2);
expect(result[0].id).toBe(initialState[1].id);
expect(result[1].id).toBe(initialState[0].id);
});
it('should return the initial state for unknown action', () => {
const result = sessionGroupsReducer(initialState, { type: 'unknown' } as any);
expect(result).toEqual(initialState);
});
});

View File

@@ -0,0 +1,56 @@
import { SessionGroupItem } from '@/types/session';
export type AddSessionGroupAction = { item: SessionGroupItem; type: 'addSessionGroupItem' };
export type DeleteSessionGroupAction = { id: string; type: 'deleteSessionGroupItem' };
export type UpdateSessionGroupAction = {
id: string;
item: Partial<SessionGroupItem>;
type: 'updateSessionGroupItem';
};
export type UpdateSessionGroupOrderAction = {
sortMap: { id: string; sort?: number }[];
type: 'updateSessionGroupOrder';
};
export type SessionGroupsDispatch =
| AddSessionGroupAction
| DeleteSessionGroupAction
| UpdateSessionGroupAction
| UpdateSessionGroupOrderAction;
export const sessionGroupsReducer = (
state: SessionGroupItem[],
payload: SessionGroupsDispatch,
): SessionGroupItem[] => {
switch (payload.type) {
case 'addSessionGroupItem': {
return [...state, payload.item];
}
case 'deleteSessionGroupItem': {
return state.filter((item) => item.id !== payload.id);
}
case 'updateSessionGroupItem': {
return state.map((item) => {
if (item.id === payload.id) {
return { ...item, ...payload.item };
}
return item;
});
}
case 'updateSessionGroupOrder': {
return state
.map((item) => {
const sort = payload.sortMap.find((i) => i.id === item.id)?.sort;
return { ...item, sort };
})
.sort((a, b) => (a.sort || 0) - (b.sort || 0));
}
default: {
return state;
}
}
};

View File

@@ -1,10 +1,6 @@
import { SessionStore } from '@/store/session';
const sessionGroupItems = (s: SessionStore) =>
s.customSessionGroups.map((group) => ({
id: group.id,
name: group.name,
}));
const sessionGroupItems = (s: SessionStore) => s.sessionGroups;
const getGroupById = (id: string) => (s: SessionStore) =>
sessionGroupItems(s).find((group) => group.id === id);

7
src/types/service.ts Normal file
View File

@@ -0,0 +1,7 @@
export interface BatchTaskResult {
added: number;
errors?: Error[];
ids: string[];
skips: string[];
success: boolean;
}

View File

@@ -44,15 +44,13 @@ export interface LobeAgentSettings {
export type LobeSessions = LobeAgentSession[];
export interface CustomSessionGroup {
export interface CustomSessionGroup extends SessionGroupItem {
children: LobeSessions;
id: SessionGroupId;
name: string;
}
export type LobeSessionGroups = SessionGroupItem[];
export interface ChatSessionList {
all: LobeSessions;
customGroup: CustomSessionGroup[];
default: LobeSessions;
pinned: LobeSessions;
sessionGroups: LobeSessionGroups;
sessions: LobeSessions;
}

View File

@@ -1,9 +1,9 @@
// generate('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', 16); //=> "4f90d13a42"
import { customAlphabet } from 'nanoid/non-secure';
export const nanoid = customAlphabet(
'1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ',
8,
);
export const createNanoId = (size = 8) =>
customAlphabet('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', size);
export const nanoid = createNanoId();
export { v4 as uuid } from 'uuid';