feat(memory-user-memory): support to configure preferred model (#11637)

This commit is contained in:
Neko
2026-01-20 16:36:49 +08:00
committed by GitHub
parent dc7f7d212b
commit 49374daab2
4 changed files with 410 additions and 21 deletions

View File

@@ -32,12 +32,18 @@ export type MemoryLayerExtractorConfig = MemoryLayerExtractorPublicConfig &
export interface MemoryExtractionPrivateConfig {
agentGateKeeper: MemoryAgentConfig;
agentGateKeeperPreferredModels?: string[];
agentGateKeeperPreferredProviders?: string[];
agentLayerExtractor: MemoryLayerExtractorConfig;
agentLayerExtractorPreferredModels?: string[];
agentLayerExtractorPreferredProviders?: string[];
concurrency?: number;
embedding: MemoryAgentConfig;
embeddingPreferredModels?: string[];
embeddingPreferredProviders?: string[];
featureFlags: {
enableBenchmarkLoCoMo: boolean;
},
};
observabilityS3?: {
accessKeyId?: string;
bucketName?: string;
@@ -157,6 +163,12 @@ const sanitizeAgent = (agent?: MemoryAgentConfig): MemoryAgentPublicConfig | und
return sanitized as MemoryAgentPublicConfig;
};
const parsePreferredList = (value?: string) =>
value
?.split(',')
.map((item) => item.trim().toLowerCase())
.filter(Boolean);
export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig => {
const agentGateKeeper = parseGateKeeperAgent();
const agentLayerExtractor = parseLayerExtractorAgent(agentGateKeeper.model);
@@ -167,8 +179,8 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
);
const extractorObservabilityS3 = parseExtractorAgentObservabilityS3();
const featureFlags = {
enableBenchmarkLoCoMo: process.env.MEMORY_USER_MEMORY_FEATURE_FLAG_BENCHMARK_LOCOMO === 'true'
}
enableBenchmarkLoCoMo: process.env.MEMORY_USER_MEMORY_FEATURE_FLAG_BENCHMARK_LOCOMO === 'true',
};
const concurrencyRaw = process.env.MEMORY_USER_MEMORY_CONCURRENCY;
const concurrency =
concurrencyRaw !== undefined
@@ -191,7 +203,9 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
return acc;
}, {});
const upstashWorkflowExtraHeaders = process.env.MEMORY_USER_MEMORY_WORKFLOW_EXTRA_HEADERS?.split(',')
const upstashWorkflowExtraHeaders = process.env.MEMORY_USER_MEMORY_WORKFLOW_EXTRA_HEADERS?.split(
',',
)
.filter(Boolean)
.reduce<Record<string, string>>((acc, pair) => {
const [key, value] = pair.split('=').map((s) => s.trim());
@@ -201,11 +215,36 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
return acc;
}, {});
const agentGateKeeperPreferredProviders = parsePreferredList(
process.env.MEMORY_USER_MEMORY_GATEKEEPER_PREFERRED_PROVIDERS,
);
const agentGateKeeperPreferredModels = parsePreferredList(
process.env.MEMORY_USER_MEMORY_GATEKEEPER_PREFERRED_MODELS,
);
const embeddingPreferredProviders = parsePreferredList(
process.env.MEMORY_USER_MEMORY_EMBEDDING_PREFERRED_PROVIDERS,
);
const embeddingPreferredModels = parsePreferredList(
process.env.MEMORY_USER_MEMORY_EMBEDDING_PREFERRED_MODELS,
);
const agentLayerExtractorPreferredProviders = parsePreferredList(
process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_PREFERRED_PROVIDERS,
);
const agentLayerExtractorPreferredModels = parsePreferredList(
process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_PREFERRED_MODELS,
);
return {
agentGateKeeper,
agentGateKeeperPreferredModels,
agentGateKeeperPreferredProviders,
agentLayerExtractor,
agentLayerExtractorPreferredModels,
agentLayerExtractorPreferredProviders,
concurrency,
embedding,
embeddingPreferredModels,
embeddingPreferredProviders,
featureFlags,
observabilityS3: extractorObservabilityS3,
upstashWorkflowExtraHeaders,

View File

@@ -0,0 +1,101 @@
import { describe, expect, it } from 'vitest';
import { LayersEnum, MemorySourceType } from '@/types/userMemory';
import {
type MemoryExtractionNormalizedPayload,
type MemoryExtractionPayloadInput,
buildWorkflowPayloadInput,
normalizeMemoryExtractionPayload,
} from '../extract';
describe('normalizeMemoryExtractionPayload', () => {
it('normalizes sources, layers, ids, and dates with fallback baseUrl', () => {
const fromDate = new Date('2024-01-01T00:00:00Z');
const toDate = new Date('2024-02-01T00:00:00Z');
const payload: MemoryExtractionPayloadInput = {
forceAll: true,
forceTopics: true,
fromDate,
identityCursor: 3,
layers: [LayersEnum.Context, LayersEnum.Identity, LayersEnum.Context],
mode: 'direct',
sourceIds: ['source-1', 'source-1', ''],
sources: ['chatTopics', 'benchmark_locomo', 'unknown'],
toDate,
topicIds: ['topic-1', 'topic-1', ''],
userId: 'user-a',
userIds: ['user-a', 'user-b', ''],
};
const normalized = normalizeMemoryExtractionPayload(payload, 'https://api.example.com');
expect(normalized.baseUrl).toBe('https://api.example.com');
expect(normalized.forceAll).toBe(true);
expect(normalized.forceTopics).toBe(true);
expect(normalized.from).toEqual(fromDate);
expect(normalized.to).toEqual(toDate);
expect(normalized.identityCursor).toBe(3);
expect(normalized.layers).toEqual([LayersEnum.Context, LayersEnum.Identity]);
expect(normalized.sources).toEqual([
MemorySourceType.ChatTopic,
MemorySourceType.BenchmarkLocomo,
]);
expect(normalized.sourceIds).toEqual(['source-1']);
expect(normalized.topicIds).toEqual(['topic-1']);
expect(normalized.userId).toBe('user-a');
expect(normalized.userIds).toEqual(['user-a', 'user-b']);
});
it('throws when baseUrl is missing in both payload and fallback', () => {
const payload: MemoryExtractionPayloadInput = {
forceAll: false,
forceTopics: false,
userIds: [],
};
expect(() => normalizeMemoryExtractionPayload(payload)).toThrow('Missing baseUrl');
});
});
describe('buildWorkflowPayloadInput', () => {
const baseNormalized: MemoryExtractionNormalizedPayload = {
baseUrl: 'https://api.example.com',
forceAll: false,
forceTopics: false,
from: undefined,
identityCursor: 0,
layers: [],
mode: 'workflow',
sourceIds: [],
sources: [MemorySourceType.ChatTopic],
to: undefined,
topicCursor: undefined,
topicIds: [],
userCursor: undefined,
userId: undefined,
userIds: ['user-x', 'user-y'],
};
it('falls back to the first user id when userId is missing', () => {
const payload = buildWorkflowPayloadInput(baseNormalized);
expect(payload.userId).toBe('user-x');
expect(payload.userIds).toEqual(['user-x', 'user-y']);
expect(payload.baseUrl).toBe('https://api.example.com');
expect(payload.mode).toBe('workflow');
});
it('preserves explicit userId when provided', () => {
const normalized: MemoryExtractionNormalizedPayload = {
...baseNormalized,
userId: 'user-z',
};
const payload = buildWorkflowPayloadInput(normalized);
expect(payload.userId).toBe('user-z');
expect(payload.userIds).toEqual(['user-x', 'user-y']);
});
});

View File

@@ -0,0 +1,111 @@
import type { AiProviderRuntimeState } from '@lobechat/types';
import type { EnabledAiModel } from 'model-bank';
import { describe, expect, it } from 'vitest';
import type { MemoryExtractionPrivateConfig } from '@/server/globalConfig/parseMemoryExtractionConfig';
import { MemoryExtractionExecutor } from '../extract';
const createRuntimeState = (models: EnabledAiModel[], keyVaults: Record<string, any>) =>
({
enabledAiModels: models,
enabledAiProviders: [],
enabledChatAiProviders: [],
enabledImageAiProviders: [],
runtimeConfig: Object.fromEntries(
Object.entries(keyVaults).map(([providerId, vault]) => [
providerId,
{ config: {}, keyVaults: vault, settings: {} },
]),
),
}) as AiProviderRuntimeState;
const createExecutor = (privateOverrides?: Partial<MemoryExtractionPrivateConfig>) => {
const basePrivateConfig: MemoryExtractionPrivateConfig = {
agentGateKeeper: { model: 'gate-2', provider: 'provider-b' },
agentLayerExtractor: {
contextLimit: 2048,
layers: {
context: 'layer-ctx',
experience: 'layer-exp',
identity: 'layer-id',
preference: 'layer-pref',
},
model: 'layer-1',
provider: 'provider-l',
},
concurrency: 1,
embedding: { model: 'embed-1', provider: 'provider-e' },
featureFlags: { enableBenchmarkLoCoMo: false },
observabilityS3: { enabled: false },
};
const serverConfig = {
aiProvider: {},
memory: {},
};
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore accessing private constructor for testing
return new MemoryExtractionExecutor(serverConfig as any, {
...basePrivateConfig,
...privateOverrides,
});
};
describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
it('prefers configured providers/models for gatekeeper, embedding, and layer extractors', () => {
const executor = createExecutor({
embeddingPreferredProviders: ['provider-e'],
agentGateKeeperPreferredModels: ['gate-1'],
agentGateKeeperPreferredProviders: ['provider-a', 'provider-b'],
agentLayerExtractorPreferredProviders: ['provider-l'],
});
const runtimeState = createRuntimeState(
[
{ abilities: {}, id: 'gate-1', providerId: 'provider-a', type: 'chat' },
{ abilities: {}, id: 'gate-2', providerId: 'provider-b', type: 'chat' },
{ abilities: {}, id: 'embed-1', providerId: 'provider-e', type: 'embedding' },
{ abilities: {}, id: 'layer-ctx', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-exp', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-id', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-pref', providerId: 'provider-l', type: 'chat' },
],
{
'provider-a': { apiKey: 'a-key' },
'provider-b': { apiKey: 'b-key' },
'provider-e': { apiKey: 'e-key' },
'provider-l': { apiKey: 'l-key' },
},
);
const keyVaults = (executor as any).resolveRuntimeKeyVaults(runtimeState);
expect(keyVaults).toMatchObject({
'provider-a': { apiKey: 'a-key' }, // gatekeeper picked preferred provider/model
'provider-e': { apiKey: 'e-key' }, // embedding honored preferred provider
'provider-l': { apiKey: 'l-key' }, // layer extractor models resolved
});
});
it('throws when no provider can satisfy an embedding model', () => {
const executor = createExecutor();
const runtimeState = createRuntimeState(
[
{ abilities: {}, id: 'gate-2', providerId: 'provider-b', type: 'chat' },
{ abilities: {}, id: 'layer-ctx', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-exp', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-id', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-pref', providerId: 'provider-l', type: 'chat' },
],
{
'provider-b': { apiKey: 'b-key' },
'provider-l': { apiKey: 'l-key' },
},
);
expect(() => (executor as any).resolveRuntimeKeyVaults(runtimeState)).toThrow(/embedding/i);
});
});

View File

@@ -38,6 +38,7 @@ import {
} from '@lobechat/observability-otel/modules/memory-user-memory';
import { attributesCommon } from '@lobechat/observability-otel/node';
import type {
AiProviderRuntimeState,
IdentityMemoryDetail,
MemoryExtractionAgentCallTrace,
MemoryExtractionTraceError,
@@ -55,6 +56,7 @@ import type { ListUsersForMemoryExtractorCursor } from '@/database/models/user';
import { UserModel } from '@/database/models/user';
import { UserMemoryModel } from '@/database/models/userMemory';
import { UserMemorySourceBenchmarkLoCoMoModel } from '@/database/models/userMemory/sources/benchmarkLoCoMo';
import { AiInfraRepos } from '@/database/repositories/aiInfra';
import { getServerDB } from '@/database/server';
import { getServerGlobalConfig } from '@/server/globalConfig';
import {
@@ -64,7 +66,7 @@ import {
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { S3 } from '@/server/modules/S3';
import type { GlobalMemoryLayer } from '@/types/serverConfig';
import type { UserKeyVaults } from '@/types/user/settings';
import type { ProviderConfig } from '@/types/user/settings';
import {
LayersEnum,
MemorySourceType,
@@ -212,6 +214,11 @@ export interface TopicBatchWorkflowPayload extends MemoryExtractionPayloadInput
userId: string;
}
type ProviderKeyVaultMap = Record<
string,
AiProviderRuntimeState['runtimeConfig'][string]['keyVaults'] | undefined
>;
export const buildWorkflowPayloadInput = (
payload: MemoryExtractionNormalizedPayload,
): MemoryExtractionPayloadInput => ({
@@ -232,11 +239,9 @@ export const buildWorkflowPayloadInput = (
userIds: payload.userIds,
});
const normalizeProvider = (provider: string) => provider.toLowerCase() as keyof UserKeyVaults;
const extractCredentialsFromVault = (provider: string, keyVaults?: UserKeyVaults) => {
const vault = keyVaults?.[normalizeProvider(provider)];
const normalizeProvider = (provider: string) => provider.toLowerCase();
const extractCredentialsFromVault = (vault?: Record<string, unknown>) => {
if (!vault || typeof vault !== 'object') return {};
const apiKey = 'apiKey' in vault && typeof vault.apiKey === 'string' ? vault.apiKey : undefined;
@@ -275,11 +280,10 @@ const maskSecret = (value?: string) => {
return `${value.slice(0, 6)}***${value.slice(-4)}`;
};
const resolveRuntimeAgentConfig = (agent: MemoryAgentConfig, keyVaults?: UserKeyVaults) => {
const resolveRuntimeAgentConfig = (agent: MemoryAgentConfig, keyVaults?: ProviderKeyVaultMap) => {
const provider = agent.provider || 'openai';
const { apiKey: userApiKey, baseURL: userBaseURL } = extractCredentialsFromVault(
provider,
keyVaults,
keyVaults?.[normalizeProvider(provider)],
);
// Only use the user baseURL if we are also using their API key; otherwise fall back entirely
@@ -309,7 +313,7 @@ const debugRuntimeInit = (
});
};
const initRuntimeForAgent = async (agent: MemoryAgentConfig, keyVaults?: UserKeyVaults) => {
const initRuntimeForAgent = async (agent: MemoryAgentConfig, keyVaults?: ProviderKeyVaultMap) => {
const resolved = resolveRuntimeAgentConfig(agent, keyVaults);
debugRuntimeInit(agent, resolved);
@@ -366,6 +370,13 @@ type MemoryExtractionConfig = ReturnType<typeof parseMemoryExtractionConfig>;
type ServerConfig = Awaited<ReturnType<typeof getServerGlobalConfig>>;
export class MemoryExtractionExecutor {
private readonly aiProviderConfig: Record<string, ProviderConfig>;
private readonly embeddingPreferredModels?: string[];
private readonly embeddingPreferredProviders?: string[];
private readonly gatekeeperPreferredModels?: string[];
private readonly gatekeeperPreferredProviders?: string[];
private readonly layerPreferredModels?: string[];
private readonly layerPreferredProviders?: string[];
private readonly privateConfig: MemoryExtractionConfig;
private readonly modelConfig: {
embeddingsModel: string;
@@ -380,6 +391,13 @@ export class MemoryExtractionExecutor {
private constructor(serverConfig: ServerConfig, privateConfig: MemoryExtractionConfig) {
this.privateConfig = privateConfig;
this.aiProviderConfig = (serverConfig.aiProvider || {}) as Record<string, ProviderConfig>;
this.embeddingPreferredProviders = privateConfig.embeddingPreferredProviders;
this.embeddingPreferredModels = privateConfig.embeddingPreferredModels;
this.gatekeeperPreferredProviders = privateConfig.agentGateKeeperPreferredProviders;
this.gatekeeperPreferredModels = privateConfig.agentGateKeeperPreferredModels;
this.layerPreferredProviders = privateConfig.agentLayerExtractorPreferredProviders;
this.layerPreferredModels = privateConfig.agentLayerExtractorPreferredModels;
const publicMemoryConfig = serverConfig.memory?.userMemory;
@@ -1014,8 +1032,11 @@ export class MemoryExtractionExecutor {
};
const userModel = new UserModel(db, job.userId);
const userState = await userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults);
const keyVaults = userState.settings?.keyVaults as UserKeyVaults | undefined;
const [userState, aiProviderRuntimeState] = await Promise.all([
userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults),
this.getAiProviderRuntimeState(job.userId),
]);
const keyVaults = this.resolveRuntimeKeyVaults(aiProviderRuntimeState);
const language = userState.settings?.general?.responseLanguage;
const runtimes = await this.getRuntime(job.userId, keyVaults);
@@ -1625,7 +1646,112 @@ export class MemoryExtractionExecutor {
};
}
private async getRuntime(userId: string, keyVaults?: UserKeyVaults): Promise<RuntimeBundle> {
private async getAiProviderRuntimeState(userId: string): Promise<AiProviderRuntimeState> {
const db = await this.db;
const aiInfraRepos = new AiInfraRepos(db, userId, this.aiProviderConfig);
return aiInfraRepos.getAiProviderRuntimeState(KeyVaultsGateKeeper.getUserKeyVaults);
}
private resolveRuntimeKeyVaults(runtimeState: AiProviderRuntimeState): ProviderKeyVaultMap {
const normalizedRuntimeConfig = Object.fromEntries(
Object.entries(runtimeState.runtimeConfig || {}).map(([providerId, config]) => [
normalizeProvider(providerId),
config,
]),
);
const providerModels = runtimeState.enabledAiModels.reduce<Record<string, Set<string>>>(
(acc, model) => {
const providerId = normalizeProvider(model.providerId);
acc[providerId] = acc[providerId] || new Set<string>();
acc[providerId].add(model.id);
return acc;
},
{},
);
const resolveProviderForModel = (
modelId: string,
fallbackProvider?: string,
preferredProviders?: string[],
preferredModels?: string[],
label?: string,
) => {
const providerOrder = Array.from(
new Set(
[
...(preferredProviders?.map(normalizeProvider) || []),
fallbackProvider ? normalizeProvider(fallbackProvider) : undefined,
...Object.keys(providerModels),
].filter(Boolean) as string[],
),
);
const candidateModels = preferredModels && preferredModels.length > 0 ? preferredModels : [];
for (const providerId of providerOrder) {
const models = providerModels[providerId];
if (!models) continue;
if (models.has(modelId)) return providerId;
const preferredMatch = candidateModels.find((preferredModel) => models.has(preferredModel));
if (preferredMatch) return providerId;
}
throw new Error(
`Unable to resolve provider for ${label || 'model'} "${modelId}". Check preferred providers/models configuration.`,
);
};
const keyVaults: ProviderKeyVaultMap = {};
const gatekeeperProvider = resolveProviderForModel(
this.modelConfig.gateModel,
this.privateConfig.agentGateKeeper.provider,
this.gatekeeperPreferredProviders,
this.gatekeeperPreferredModels,
'gatekeeper',
);
const gatekeeperRuntime = normalizedRuntimeConfig[gatekeeperProvider];
if (gatekeeperRuntime?.keyVaults) {
keyVaults[gatekeeperProvider] = gatekeeperRuntime.keyVaults;
}
const embeddingProvider = resolveProviderForModel(
this.modelConfig.embeddingsModel,
this.privateConfig.embedding.provider,
this.embeddingPreferredProviders,
this.embeddingPreferredModels,
'embedding',
);
const embeddingRuntime = normalizedRuntimeConfig[embeddingProvider];
if (embeddingRuntime?.keyVaults) {
keyVaults[embeddingProvider] = embeddingRuntime.keyVaults;
}
Object.values(this.modelConfig.layerModels).forEach((model) => {
if (!model) return;
const providerId = resolveProviderForModel(
model,
this.privateConfig.agentLayerExtractor.provider,
this.layerPreferredProviders,
this.layerPreferredModels,
'layer extractor',
);
const runtime = normalizedRuntimeConfig[providerId];
if (runtime?.keyVaults) {
keyVaults[providerId] = runtime.keyVaults;
}
});
return keyVaults;
}
private async getRuntime(
userId: string,
keyVaults?: ProviderKeyVaultMap,
): Promise<RuntimeBundle> {
// TODO: implement a better cache eviction strategy
// TODO: make cache size configurable
if (this.runtimeCache.keys.length > 200) {
@@ -1673,8 +1799,11 @@ export class MemoryExtractionExecutor {
try {
const db = await this.db;
const userModel = new UserModel(db, params.userId);
const userState = await userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults);
const keyVaults = userState.settings?.keyVaults as UserKeyVaults | undefined;
const [userState, aiProviderRuntimeState] = await Promise.all([
userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults),
this.getAiProviderRuntimeState(params.userId),
]);
const keyVaults = this.resolveRuntimeKeyVaults(aiProviderRuntimeState);
const language = params.language || userState.settings?.general?.responseLanguage;
const runtimes = await this.getRuntime(params.userId, keyVaults);
@@ -1849,7 +1978,10 @@ export class MemoryExtractionWorkflowService {
return this.client;
}
static triggerProcessUsers(payload: MemoryExtractionPayloadInput, options?: { extraHeaders?: Record<string, string> }) {
static triggerProcessUsers(
payload: MemoryExtractionPayloadInput,
options?: { extraHeaders?: Record<string, string> },
) {
if (!payload.baseUrl) {
throw new Error('Missing baseUrl for workflow trigger');
}
@@ -1858,7 +1990,10 @@ export class MemoryExtractionWorkflowService {
return this.getClient().trigger({ body: payload, headers: options?.extraHeaders, url });
}
static triggerProcessUserTopics(payload: UserTopicWorkflowPayload, options?: { extraHeaders?: Record<string, string> }) {
static triggerProcessUserTopics(
payload: UserTopicWorkflowPayload,
options?: { extraHeaders?: Record<string, string> },
) {
if (!payload.baseUrl) {
throw new Error('Missing baseUrl for workflow trigger');
}
@@ -1867,7 +2002,10 @@ export class MemoryExtractionWorkflowService {
return this.getClient().trigger({ body: payload, headers: options?.extraHeaders, url });
}
static triggerProcessTopics(payload: MemoryExtractionPayloadInput, options?: { extraHeaders?: Record<string, string> }) {
static triggerProcessTopics(
payload: MemoryExtractionPayloadInput,
options?: { extraHeaders?: Record<string, string> },
) {
if (!payload.baseUrl) {
throw new Error('Missing baseUrl for workflow trigger');
}