feat(model): improve model list UI and add disabled models management (#11036)

* feat(model): improve model list UI and add disabled models management

- Enhanced DisabledModels component with better UI/UX
- Updated ModelList layout and interactions
- Added repository methods for disabled model management
- Improved AI model service and router functionality
- Added tests for new functionality

* feat(DisabledModels): enhance loading and rendering logic for disabled models

- Implemented pagination and dynamic loading for disabled models
- Improved state management for visible models and loading conditions
- Ensured unique model entries in the displayed list
- Updated component to handle provider changes effectively

Signed-off-by: Innei <tukon479@gmail.com>

* fix(DisabledModels): handle edge case for last page in pagination logic

- Added a check to ensure lastPage is defined before evaluating pagination end conditions
- Improved robustness of loading state management in DisabledModels component

Signed-off-by: Innei <tukon479@gmail.com>

* lint

* lint

* lint

---------

Signed-off-by: Innei <tukon479@gmail.com>
This commit is contained in:
Innei
2025-12-30 16:49:12 +08:00
committed by GitHub
parent 381cf51ec0
commit 4faa65c6af
10 changed files with 290 additions and 63 deletions

View File

@@ -57,7 +57,7 @@ const { spawnMock } = vi.hoisted(() => ({
}));
vi.mock('node:child_process', () => ({
spawn: (...args: any[]) => spawnMock(...args),
spawn: (...args: any[]) => spawnMock.call(null, ...args),
}));
// Mock electron
@@ -229,7 +229,9 @@ describe('SystemController', () => {
await invokeIpc('system.openFullDiskAccessSettings');
expect(shell.openExternal).toHaveBeenCalledWith('com.apple.settings:Privacy&path=FullDiskAccess');
expect(shell.openExternal).toHaveBeenCalledWith(
'com.apple.settings:Privacy&path=FullDiskAccess',
);
expect(shell.openExternal).toHaveBeenCalledWith(
'x-apple.systempreferences:com.apple.preference.security?Privacy_AllFiles',
);

View File

@@ -913,6 +913,55 @@ describe('AiInfraRepos', () => {
expect(result).toHaveLength(0);
});
it('should support offset/limit pagination', async () => {
const providerId = 'openai';
const userModels = Array.from({ length: 5 }).map((_, i) => ({
enabled: i % 2 === 0,
id: `u-${i + 1}`,
type: 'chat',
})) as AiProviderModelListItem[];
const builtinModels = Array.from({ length: 5 }).map((_, i) => ({
enabled: true,
id: `b-${i + 1}`,
type: 'chat',
})) as AiProviderModelListItem[];
vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels);
const all = await repo.getAiProviderModelList(providerId);
const result = await repo.getAiProviderModelList(providerId, { limit: 3, offset: 2 });
expect(result.map((i) => i.id)).toEqual(all.slice(2, 5).map((i) => i.id));
});
it('should support enabled filter with pagination', async () => {
const providerId = 'openai';
const userModels = [
{ enabled: false, id: 'u-1', type: 'chat' },
{ enabled: true, id: 'u-2', type: 'chat' },
{ enabled: false, id: 'u-3', type: 'chat' },
] as AiProviderModelListItem[];
const builtinModels = [
{ enabled: false, id: 'b-1', type: 'chat' },
{ enabled: true, id: 'b-2', type: 'chat' },
] as AiProviderModelListItem[];
vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels);
vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels);
const result = await repo.getAiProviderModelList(providerId, {
enabled: false,
limit: 10,
offset: 0,
});
expect(result.map((i) => i.id)).toEqual(['u-1', 'u-3', 'b-1']);
});
// New tests for getAiProviderModelList per the corrected behavior
it('should allow search=true and add searchImpl=params when user enables it without providing settings (builtin has no search and no settings)', async () => {
const providerId = 'openai';
@@ -946,9 +995,9 @@ describe('AiInfraRepos', () => {
const merged = result.find((m) => m.id === 'gpt-4');
expect(merged).toBeDefined();
expect(merged.abilities).toMatchObject({ search: true });
expect(merged!.abilities).toMatchObject({ search: true });
// when user enables search with no settings, default searchImpl should be 'params'
expect(merged.settings).toEqual({ searchImpl: 'params' });
expect(merged!.settings).toEqual({ searchImpl: 'params' });
});
it('should remove builtin search settings and disable search when user turns search off', async () => {
@@ -984,9 +1033,9 @@ describe('AiInfraRepos', () => {
const merged = result.find((m) => m.id === 'gpt-4');
expect(merged).toBeDefined();
// User's choice takes precedence
expect(merged.abilities).toMatchObject({ search: false });
expect(merged!.abilities).toMatchObject({ search: false });
// Builtin search settings should be removed since user turned search off
expect(merged.settings).toBeUndefined();
expect(merged!.settings).toBeUndefined();
});
it('should set search=true and settings=params for custom provider when user enables search and builtin has no search/settings', async () => {

View File

@@ -280,7 +280,14 @@ export class AiInfraRepos {
};
};
getAiProviderModelList = async (providerId: string) => {
getAiProviderModelList = async (
providerId: string,
options?: {
enabled?: boolean;
limit?: number;
offset?: number;
},
) => {
const aiModels = await this.aiModelModel.getModelListByProviderId(providerId);
const defaultModels: AiProviderModelListItem[] =
@@ -288,7 +295,22 @@ export class AiInfraRepos {
// Not modifying search settings here doesn't affect usage, but done for data consistency on get
const mergedModel = mergeArrayById(defaultModels, aiModels) as AiProviderModelListItem[];
return mergedModel.map((m) => injectSearchSettings(providerId, m));
let list = mergedModel.map((m) =>
injectSearchSettings(providerId, m),
) as AiProviderModelListItem[];
if (typeof options?.enabled === 'boolean') {
list = list.filter((m) => m.enabled === options.enabled);
}
if (typeof options?.offset === 'number' || typeof options?.limit === 'number') {
const offset = Math.max(0, options?.offset ?? 0);
const limit = options?.limit;
if (typeof limit === 'number') return list.slice(offset, offset + Math.max(0, limit));
return list.slice(offset);
}
return list;
};
/**

View File

@@ -1,10 +1,13 @@
import { ActionIcon, Button, Dropdown, Flexbox, Icon, Text, TooltipGroup } from '@lobehub/ui';
import { ActionIcon, Dropdown, Flexbox, Icon, Text, TooltipGroup } from '@lobehub/ui';
import type { ItemType } from 'antd/es/menu/interface';
import isEqual from 'fast-deep-equal';
import { ArrowDownUpIcon, ChevronDown, LucideCheck } from 'lucide-react';
import { memo, useCallback, useMemo, useState } from 'react';
import { ArrowDownUpIcon, LucideCheck } from 'lucide-react';
import type { AiProviderModelListItem } from 'model-bank';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import useSWRInfinite from 'swr/infinite';
import { aiModelService } from '@/services/aiModel';
import { useAiInfraStore } from '@/store/aiInfra';
import { aiModelSelectors } from '@/store/aiInfra/selectors';
import { useGlobalStore } from '@/store/global';
@@ -14,6 +17,7 @@ import ModelItem from './ModelItem';
interface DisabledModelsProps {
activeTab: string;
providerId: string;
}
// Sort type enumeration
@@ -25,16 +29,21 @@ enum SortType {
ReleasedAtDesc = 'releasedAtDesc',
}
const DisabledModels = memo<DisabledModelsProps>(({ activeTab }) => {
const { t } = useTranslation('modelProvider');
const PAGE_SIZE = 30;
const FETCH_DISABLED_MODELS_PAGE_KEY = 'FETCH_DISABLED_MODELS_PAGE';
const [showMore, setShowMore] = useState(false);
const DisabledModels = memo<DisabledModelsProps>(({ activeTab, providerId }) => {
const { t } = useTranslation(['modelProvider', 'common']);
const [sortType, updateSystemStatus] = useGlobalStore((s) => [
systemStatusSelectors.disabledModelsSortType(s),
s.updateSystemStatus,
]);
// Only render at most PAGE_SIZE items from the store on first paint.
// As user scrolls, we reveal more store items in PAGE_SIZE steps, then start remote loading.
const [visibleBaseSize, setVisibleBaseSize] = useState(PAGE_SIZE);
const updateSortType = useCallback(
(newSortType: SortType) => {
updateSystemStatus({ disabledModelsSortType: newSortType });
@@ -42,13 +51,133 @@ const DisabledModels = memo<DisabledModelsProps>(({ activeTab }) => {
[updateSystemStatus],
);
// initial render source: provider list already in store (typically built-in + user merged)
const disabledModels = useAiInfraStore(aiModelSelectors.disabledAiProviderModelList, isEqual);
const baseIds = useMemo(() => new Set(disabledModels.map((m) => m.id)), [disabledModels]);
useEffect(() => {
// reset visible window on provider change
setVisibleBaseSize(PAGE_SIZE);
}, [providerId]);
const visibleBaseModels = useMemo(
() => disabledModels.slice(0, visibleBaseSize),
[disabledModels, visibleBaseSize],
);
const hasMoreBase = visibleBaseSize < disabledModels.length;
const remoteEnabled = !!providerId && !hasMoreBase;
const getKey = useCallback(
(pageIndex: number, previousPageData: AiProviderModelListItem[] | null) => {
if (!remoteEnabled) return null;
if (previousPageData && previousPageData.length < PAGE_SIZE) return null;
// start fetching after the initial list from store
const offset = disabledModels.length + pageIndex * PAGE_SIZE;
return [FETCH_DISABLED_MODELS_PAGE_KEY, providerId, offset] as const;
},
[disabledModels.length, providerId, remoteEnabled],
);
const {
data: pages,
error,
isValidating,
setSize,
size,
} = useSWRInfinite<AiProviderModelListItem[]>(getKey, async ([, id, offset]) => {
return aiModelService.getAiProviderModelList(id as string, {
enabled: false,
limit: PAGE_SIZE,
offset: offset as number,
});
});
const pagedDisabledModels = useMemo(() => (pages ? pages.flat() : []), [pages]);
// ensure "load more" pages do not duplicate the initial store list
const appendedDisabledModels = useMemo(() => {
if (!pagedDisabledModels.length) return [];
return pagedDisabledModels.filter((m) => !baseIds.has(m.id));
}, [baseIds, pagedDisabledModels]);
const mergedDisabledModels = useMemo(() => {
// keep store order for initial items; append new ones afterwards
const exists = new Set<string>();
const merged: AiProviderModelListItem[] = [];
visibleBaseModels.forEach((m) => {
if (exists.has(m.id)) return;
exists.add(m.id);
merged.push(m);
});
appendedDisabledModels.forEach((m) => {
if (exists.has(m.id)) return;
exists.add(m.id);
merged.push(m);
});
return merged;
}, [appendedDisabledModels, visibleBaseModels]);
const isInitialLoading = remoteEnabled && !pages && !error;
const isReachingEnd = useMemo(() => {
if (!pages || pages.length === 0) return false;
const lastPage = pages.at(-1);
if (!lastPage) return false;
return lastPage.length < PAGE_SIZE;
}, [pages]);
const isLoadingMore = isValidating && size > 0 && !!pages && pages.length < size;
const loadMoreRef = useRef<HTMLDivElement>(null);
const triggerLoadMore = useCallback(() => {
if (hasMoreBase) {
setVisibleBaseSize((v) => Math.min(v + PAGE_SIZE, disabledModels.length));
return;
}
if (isReachingEnd) return;
if (isValidating) return;
setSize(size + 1);
}, [disabledModels.length, hasMoreBase, isReachingEnd, isValidating, setSize, size]);
useEffect(() => {
if (!hasMoreBase && isReachingEnd) return;
if (!loadMoreRef.current) return;
const observer = new IntersectionObserver(
(entries) => {
entries.forEach((entry) => {
if (!entry.isIntersecting) return;
triggerLoadMore();
});
},
{
rootMargin: '200px',
threshold: 0.01,
},
);
observer.observe(loadMoreRef.current);
return () => {
observer.disconnect();
};
}, [hasMoreBase, isReachingEnd, triggerLoadMore]);
const sourceDisabledModels = mergedDisabledModels;
const shouldRenderSection =
disabledModels.length > 0 || isInitialLoading || sourceDisabledModels.length > 0;
// Filter models based on active tab
const filteredDisabledModels = useMemo(() => {
if (activeTab === 'all') return disabledModels;
return disabledModels.filter((model) => model.type === activeTab);
}, [disabledModels, activeTab]);
if (activeTab === 'all') return sourceDisabledModels;
return sourceDisabledModels.filter((model) => model.type === activeTab);
}, [activeTab, sourceDisabledModels]);
// Sort models based on sort type
const sortedDisabledModels = useMemo(() => {
@@ -101,16 +230,16 @@ const DisabledModels = memo<DisabledModelsProps>(({ activeTab }) => {
}
}, [filteredDisabledModels, sortType]);
const displayModels = showMore ? sortedDisabledModels : sortedDisabledModels.slice(0, 10);
const displayModels = sortedDisabledModels;
return (
filteredDisabledModels.length > 0 && (
shouldRenderSection && (
<Flexbox>
<Flexbox align="center" horizontal justify="space-between">
<Text style={{ fontSize: 12, marginTop: 8 }} type={'secondary'}>
{t('providerModels.list.disabled')}
</Text>
{filteredDisabledModels.length > 1 && (
{sourceDisabledModels.length > 1 && (
<Dropdown
menu={{
items: [
@@ -174,18 +303,15 @@ const DisabledModels = memo<DisabledModelsProps>(({ activeTab }) => {
<ModelItem {...item} key={item.id} />
))}
</TooltipGroup>
{!showMore && sortedDisabledModels.length > 10 && (
<Button
block
icon={ChevronDown}
onClick={() => {
setShowMore(true);
}}
size={'small'}
>
{t('providerModels.list.disabledActions.showMore')}
</Button>
)}
<Flexbox align="center" horizontal justify="center" paddingBlock={8}>
<div ref={loadMoreRef} style={{ height: 1, width: '0' }} />
{(isInitialLoading || isLoadingMore) && (
<Text style={{ fontSize: 12, marginTop: 4 }} type={'secondary'}>
{t('common:loading')}
</Text>
)}
</Flexbox>
</Flexbox>
)
);

View File

@@ -38,31 +38,33 @@ const EnabledModelList = ({ activeTab }: EnabledModelListProps) => {
{t('providerModels.list.enabled')}
</Text>
{!isEmpty && (
<Flexbox horizontal>
<ActionIcon
icon={ToggleLeft}
loading={batchLoading}
onClick={async () => {
setBatchLoading(true);
await batchToggleAiModels(
enabledModels.map((i) => i.id),
false,
);
setBatchLoading(false);
}}
size={'small'}
title={t('providerModels.list.enabledActions.disableAll')}
/>
<TooltipGroup>
<Flexbox horizontal>
<ActionIcon
icon={ToggleLeft}
loading={batchLoading}
onClick={async () => {
setBatchLoading(true);
await batchToggleAiModels(
enabledModels.map((i) => i.id),
false,
);
setBatchLoading(false);
}}
size={'small'}
title={t('providerModels.list.enabledActions.disableAll')}
/>
<ActionIcon
icon={ArrowDownUpIcon}
onClick={() => {
setOpen(true);
}}
size={'small'}
title={t('providerModels.list.enabledActions.sort')}
/>
</Flexbox>
<ActionIcon
icon={ArrowDownUpIcon}
onClick={() => {
setOpen(true);
}}
size={'small'}
title={t('providerModels.list.enabledActions.sort')}
/>
</Flexbox>
</TooltipGroup>
)}
{open && (
<SortModelModal

View File

@@ -33,7 +33,8 @@ interface ContentProps {
}
const Content = memo<ContentProps>(({ id }) => {
const { t } = useTranslation('modelProvider');
// preload common namespace to avoid Suspense remount when child components start using it (e.g. infinite scroll loading text)
const { t } = useTranslation(['modelProvider', 'common']);
const [activeTab, setActiveTab] = useState('all');
const [isSearching, isEmpty, useFetchAiProviderModels] = useAiInfraStore((s) => [
@@ -135,7 +136,7 @@ const Content = memo<ContentProps>(({ id }) => {
style={{ marginBottom: 12, marginLeft: -6 }}
/>
<EnabledModelList activeTab={currentActiveTab} />
<DisabledModels activeTab={currentActiveTab} />
<DisabledModels activeTab={currentActiveTab} providerId={id} />
</Flexbox>
);
});

View File

@@ -57,6 +57,7 @@ const resources = {
file,
home,
hotkey,
image,
knowledgeBase,

View File

@@ -90,7 +90,11 @@ describe('aiModelRouter', () => {
const result = await caller.getAiProviderModelList({ id: 'provider-1' });
expect(result).toEqual(mockModelList);
expect(mockGetList).toHaveBeenCalledWith('provider-1');
expect(mockGetList).toHaveBeenCalledWith('provider-1', {
enabled: undefined,
limit: undefined,
offset: undefined,
});
});
it('should remove ai model', async () => {

View File

@@ -85,9 +85,20 @@ export const aiModelRouter = router({
}),
getAiProviderModelList: aiModelProcedure
.input(z.object({ id: z.string() }))
.input(
z.object({
enabled: z.boolean().optional(),
id: z.string(),
limit: z.number().int().min(1).max(200).optional(),
offset: z.number().int().min(0).optional(),
}),
)
.query(async ({ ctx, input }): Promise<AiProviderModelListItem[]> => {
return ctx.aiInfraRepos.getAiProviderModelList(input.id);
return ctx.aiInfraRepos.getAiProviderModelList(input.id, {
enabled: input.enabled,
limit: input.limit,
offset: input.offset,
});
}),
removeAiModel: aiModelProcedure

View File

@@ -8,13 +8,22 @@ import {
import { lambdaClient } from '@/libs/trpc/client';
export interface GetAiProviderModelListParams {
enabled?: boolean;
limit?: number;
offset?: number;
}
export class AiModelService {
createAiModel = async (params: CreateAiModelParams) => {
return lambdaClient.aiModel.createAiModel.mutate(params);
};
getAiProviderModelList = async (id: string): Promise<AiProviderModelListItem[]> => {
return lambdaClient.aiModel.getAiProviderModelList.query({ id });
getAiProviderModelList = async (
id: string,
params?: GetAiProviderModelListParams,
): Promise<AiProviderModelListItem[]> => {
return lambdaClient.aiModel.getAiProviderModelList.query({ id, ...params });
};
getAiModelById = async (id: string) => {