mirror of
https://github.com/lobehub/lobehub.git
synced 2026-03-27 13:29:15 +07:00
✨ feat(memory-user-memory): support to configure preferred model (#11637)
This commit is contained in:
@@ -32,12 +32,18 @@ export type MemoryLayerExtractorConfig = MemoryLayerExtractorPublicConfig &
|
||||
|
||||
export interface MemoryExtractionPrivateConfig {
|
||||
agentGateKeeper: MemoryAgentConfig;
|
||||
agentGateKeeperPreferredModels?: string[];
|
||||
agentGateKeeperPreferredProviders?: string[];
|
||||
agentLayerExtractor: MemoryLayerExtractorConfig;
|
||||
agentLayerExtractorPreferredModels?: string[];
|
||||
agentLayerExtractorPreferredProviders?: string[];
|
||||
concurrency?: number;
|
||||
embedding: MemoryAgentConfig;
|
||||
embeddingPreferredModels?: string[];
|
||||
embeddingPreferredProviders?: string[];
|
||||
featureFlags: {
|
||||
enableBenchmarkLoCoMo: boolean;
|
||||
},
|
||||
};
|
||||
observabilityS3?: {
|
||||
accessKeyId?: string;
|
||||
bucketName?: string;
|
||||
@@ -157,6 +163,12 @@ const sanitizeAgent = (agent?: MemoryAgentConfig): MemoryAgentPublicConfig | und
|
||||
return sanitized as MemoryAgentPublicConfig;
|
||||
};
|
||||
|
||||
/**
 * Parses a comma-separated string into a normalized preference list.
 *
 * Every entry is trimmed and lower-cased, empty entries are discarded, and
 * duplicates are removed while keeping the order of first occurrence (earlier
 * entries therefore keep their higher priority).
 *
 * @param value - Raw comma-separated string, e.g. the value of an environment variable.
 * @returns The normalized entries, or `undefined` when no value was supplied.
 *   An empty or separator-only string yields an empty array.
 */
const parsePreferredList = (value?: string): string[] | undefined => {
  if (typeof value !== 'string') return undefined;

  const entries = value
    .split(',')
    .map((item) => item.trim().toLowerCase())
    .filter(Boolean);

  // A Set iterates in insertion order, so the first occurrence of each entry wins.
  return Array.from(new Set(entries));
};
|
||||
|
||||
export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig => {
|
||||
const agentGateKeeper = parseGateKeeperAgent();
|
||||
const agentLayerExtractor = parseLayerExtractorAgent(agentGateKeeper.model);
|
||||
@@ -167,8 +179,8 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
|
||||
);
|
||||
const extractorObservabilityS3 = parseExtractorAgentObservabilityS3();
|
||||
const featureFlags = {
|
||||
enableBenchmarkLoCoMo: process.env.MEMORY_USER_MEMORY_FEATURE_FLAG_BENCHMARK_LOCOMO === 'true'
|
||||
}
|
||||
enableBenchmarkLoCoMo: process.env.MEMORY_USER_MEMORY_FEATURE_FLAG_BENCHMARK_LOCOMO === 'true',
|
||||
};
|
||||
const concurrencyRaw = process.env.MEMORY_USER_MEMORY_CONCURRENCY;
|
||||
const concurrency =
|
||||
concurrencyRaw !== undefined
|
||||
@@ -191,7 +203,9 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
const upstashWorkflowExtraHeaders = process.env.MEMORY_USER_MEMORY_WORKFLOW_EXTRA_HEADERS?.split(',')
|
||||
const upstashWorkflowExtraHeaders = process.env.MEMORY_USER_MEMORY_WORKFLOW_EXTRA_HEADERS?.split(
|
||||
',',
|
||||
)
|
||||
.filter(Boolean)
|
||||
.reduce<Record<string, string>>((acc, pair) => {
|
||||
const [key, value] = pair.split('=').map((s) => s.trim());
|
||||
@@ -201,11 +215,36 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
const agentGateKeeperPreferredProviders = parsePreferredList(
|
||||
process.env.MEMORY_USER_MEMORY_GATEKEEPER_PREFERRED_PROVIDERS,
|
||||
);
|
||||
const agentGateKeeperPreferredModels = parsePreferredList(
|
||||
process.env.MEMORY_USER_MEMORY_GATEKEEPER_PREFERRED_MODELS,
|
||||
);
|
||||
const embeddingPreferredProviders = parsePreferredList(
|
||||
process.env.MEMORY_USER_MEMORY_EMBEDDING_PREFERRED_PROVIDERS,
|
||||
);
|
||||
const embeddingPreferredModels = parsePreferredList(
|
||||
process.env.MEMORY_USER_MEMORY_EMBEDDING_PREFERRED_MODELS,
|
||||
);
|
||||
const agentLayerExtractorPreferredProviders = parsePreferredList(
|
||||
process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_PREFERRED_PROVIDERS,
|
||||
);
|
||||
const agentLayerExtractorPreferredModels = parsePreferredList(
|
||||
process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_PREFERRED_MODELS,
|
||||
);
|
||||
|
||||
return {
|
||||
agentGateKeeper,
|
||||
agentGateKeeperPreferredModels,
|
||||
agentGateKeeperPreferredProviders,
|
||||
agentLayerExtractor,
|
||||
agentLayerExtractorPreferredModels,
|
||||
agentLayerExtractorPreferredProviders,
|
||||
concurrency,
|
||||
embedding,
|
||||
embeddingPreferredModels,
|
||||
embeddingPreferredProviders,
|
||||
featureFlags,
|
||||
observabilityS3: extractorObservabilityS3,
|
||||
upstashWorkflowExtraHeaders,
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { LayersEnum, MemorySourceType } from '@/types/userMemory';
|
||||
|
||||
import {
|
||||
type MemoryExtractionNormalizedPayload,
|
||||
type MemoryExtractionPayloadInput,
|
||||
buildWorkflowPayloadInput,
|
||||
normalizeMemoryExtractionPayload,
|
||||
} from '../extract';
|
||||
|
||||
describe('normalizeMemoryExtractionPayload', () => {
  it('normalizes sources, layers, ids, and dates with fallback baseUrl', () => {
    const fallbackBaseUrl = 'https://api.example.com';
    const rangeStart = new Date('2024-01-01T00:00:00Z');
    const rangeEnd = new Date('2024-02-01T00:00:00Z');

    // The input deliberately carries duplicates, empty strings, and an unknown
    // source name; the expectations below require all of them to be gone.
    const input: MemoryExtractionPayloadInput = {
      forceAll: true,
      forceTopics: true,
      fromDate: rangeStart,
      identityCursor: 3,
      layers: [LayersEnum.Context, LayersEnum.Identity, LayersEnum.Context],
      mode: 'direct',
      sourceIds: ['source-1', 'source-1', ''],
      sources: ['chatTopics', 'benchmark_locomo', 'unknown'],
      toDate: rangeEnd,
      topicIds: ['topic-1', 'topic-1', ''],
      userId: 'user-a',
      userIds: ['user-a', 'user-b', ''],
    };

    const result = normalizeMemoryExtractionPayload(input, fallbackBaseUrl);

    // Scalar fields are carried over; baseUrl comes from the fallback argument.
    expect(result.baseUrl).toBe(fallbackBaseUrl);
    expect(result.forceAll).toBe(true);
    expect(result.forceTopics).toBe(true);
    expect(result.from).toEqual(rangeStart);
    expect(result.to).toEqual(rangeEnd);
    expect(result.identityCursor).toBe(3);

    // List fields come back without duplicates, blanks, or unknown sources.
    expect(result.layers).toEqual([LayersEnum.Context, LayersEnum.Identity]);
    expect(result.sources).toEqual([MemorySourceType.ChatTopic, MemorySourceType.BenchmarkLocomo]);
    expect(result.sourceIds).toEqual(['source-1']);
    expect(result.topicIds).toEqual(['topic-1']);
    expect(result.userId).toBe('user-a');
    expect(result.userIds).toEqual(['user-a', 'user-b']);
  });

  it('throws when baseUrl is missing in both payload and fallback', () => {
    const input: MemoryExtractionPayloadInput = {
      forceAll: false,
      forceTopics: false,
      userIds: [],
    };

    // No baseUrl in the payload and no fallback argument supplied.
    const normalizeWithoutBaseUrl = () => normalizeMemoryExtractionPayload(input);

    expect(normalizeWithoutBaseUrl).toThrow('Missing baseUrl');
  });
});
|
||||
|
||||
describe('buildWorkflowPayloadInput', () => {
  const candidateUserIds = ['user-x', 'user-y'];

  // Shared fixture: a normalized payload with no explicit userId.
  const normalizedWithoutUserId: MemoryExtractionNormalizedPayload = {
    baseUrl: 'https://api.example.com',
    forceAll: false,
    forceTopics: false,
    from: undefined,
    identityCursor: 0,
    layers: [],
    mode: 'workflow',
    sourceIds: [],
    sources: [MemorySourceType.ChatTopic],
    to: undefined,
    topicCursor: undefined,
    topicIds: [],
    userCursor: undefined,
    userId: undefined,
    userIds: ['user-x', 'user-y'],
  };

  it('falls back to the first user id when userId is missing', () => {
    const result = buildWorkflowPayloadInput(normalizedWithoutUserId);

    expect(result.userId).toBe('user-x');
    expect(result.userIds).toEqual(candidateUserIds);
    expect(result.baseUrl).toBe('https://api.example.com');
    expect(result.mode).toBe('workflow');
  });

  it('preserves explicit userId when provided', () => {
    const normalizedWithUserId: MemoryExtractionNormalizedPayload = {
      ...normalizedWithoutUserId,
      userId: 'user-z',
    };

    const result = buildWorkflowPayloadInput(normalizedWithUserId);

    expect(result.userId).toBe('user-z');
    expect(result.userIds).toEqual(candidateUserIds);
  });
});
|
||||
@@ -0,0 +1,111 @@
|
||||
import type { AiProviderRuntimeState } from '@lobechat/types';
|
||||
import type { EnabledAiModel } from 'model-bank';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import type { MemoryExtractionPrivateConfig } from '@/server/globalConfig/parseMemoryExtractionConfig';
|
||||
|
||||
import { MemoryExtractionExecutor } from '../extract';
|
||||
|
||||
/**
 * Builds a minimal `AiProviderRuntimeState` fixture.
 *
 * @param models - Models reported as enabled.
 * @param keyVaults - Map of provider id to the key vault object for that provider.
 *   Each provider gets a runtime config entry with empty `config` and `settings`.
 */
const createRuntimeState = (models: EnabledAiModel[], keyVaults: Record<string, any>) => {
  const runtimeConfig: Record<string, unknown> = {};

  for (const [providerId, vault] of Object.entries(keyVaults)) {
    runtimeConfig[providerId] = { config: {}, keyVaults: vault, settings: {} };
  }

  return {
    enabledAiModels: models,
    enabledAiProviders: [],
    enabledChatAiProviders: [],
    enabledImageAiProviders: [],
    runtimeConfig,
  } as AiProviderRuntimeState;
};
|
||||
|
||||
/**
 * Instantiates a `MemoryExtractionExecutor` for tests through its private constructor.
 *
 * @param privateOverrides - Fields that replace the defaults of the private
 *   memory-extraction config declared below (shallow merge).
 * @returns An executor built from an empty server config and the merged private config.
 */
const createExecutor = (privateOverrides?: Partial<MemoryExtractionPrivateConfig>) => {
  const basePrivateConfig: MemoryExtractionPrivateConfig = {
    agentGateKeeper: { model: 'gate-2', provider: 'provider-b' },
    agentLayerExtractor: {
      contextLimit: 2048,
      layers: {
        context: 'layer-ctx',
        experience: 'layer-exp',
        identity: 'layer-id',
        preference: 'layer-pref',
      },
      model: 'layer-1',
      provider: 'provider-l',
    },
    concurrency: 1,
    embedding: { model: 'embed-1', provider: 'provider-e' },
    featureFlags: { enableBenchmarkLoCoMo: false },
    observabilityS3: { enabled: false },
  };

  const serverConfig = {
    aiProvider: {},
    memory: {},
  };

  // `@ts-expect-error` (instead of `@ts-ignore`) makes the compiler complain if the
  // constructor ever stops being private, so this suppression cannot go stale.
  // eslint-disable-next-line @typescript-eslint/ban-ts-comment
  // @ts-expect-error accessing private constructor for testing
  return new MemoryExtractionExecutor(serverConfig as any, {
    ...basePrivateConfig,
    ...privateOverrides,
  });
};
|
||||
|
||||
describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
  // Small factory for enabled-model fixtures; every fixture uses empty abilities.
  const enabledModel = (
    id: string,
    providerId: string,
    type: EnabledAiModel['type'] = 'chat',
  ): EnabledAiModel => ({ abilities: {}, id, providerId, type });

  // The four layer models configured by `createExecutor`, all served by provider-l.
  const layerExtractorModels = (): EnabledAiModel[] =>
    ['layer-ctx', 'layer-exp', 'layer-id', 'layer-pref'].map((id) =>
      enabledModel(id, 'provider-l'),
    );

  it('prefers configured providers/models for gatekeeper, embedding, and layer extractors', () => {
    const executor = createExecutor({
      embeddingPreferredProviders: ['provider-e'],
      agentGateKeeperPreferredModels: ['gate-1'],
      agentGateKeeperPreferredProviders: ['provider-a', 'provider-b'],
      agentLayerExtractorPreferredProviders: ['provider-l'],
    });

    const enabledModels: EnabledAiModel[] = [
      enabledModel('gate-1', 'provider-a'),
      enabledModel('gate-2', 'provider-b'),
      enabledModel('embed-1', 'provider-e', 'embedding'),
      ...layerExtractorModels(),
    ];
    const providerVaults = {
      'provider-a': { apiKey: 'a-key' },
      'provider-b': { apiKey: 'b-key' },
      'provider-e': { apiKey: 'e-key' },
      'provider-l': { apiKey: 'l-key' },
    };
    const runtimeState = createRuntimeState(enabledModels, providerVaults);

    const resolved = (executor as any).resolveRuntimeKeyVaults(runtimeState);

    expect(resolved).toMatchObject({
      // gatekeeper: the preferred provider/model pair wins
      'provider-a': { apiKey: 'a-key' },
      // embedding: the preferred provider is honored
      'provider-e': { apiKey: 'e-key' },
      // layer extractor: every layer model resolves to its provider
      'provider-l': { apiKey: 'l-key' },
    });
  });

  it('throws when no provider can satisfy an embedding model', () => {
    const executor = createExecutor();

    // No embedding model is enabled for any provider in this state.
    const enabledModels: EnabledAiModel[] = [
      enabledModel('gate-2', 'provider-b'),
      ...layerExtractorModels(),
    ];
    const providerVaults = {
      'provider-b': { apiKey: 'b-key' },
      'provider-l': { apiKey: 'l-key' },
    };
    const runtimeState = createRuntimeState(enabledModels, providerVaults);

    const resolve = () => (executor as any).resolveRuntimeKeyVaults(runtimeState);

    expect(resolve).toThrow(/embedding/i);
  });
});
|
||||
@@ -38,6 +38,7 @@ import {
|
||||
} from '@lobechat/observability-otel/modules/memory-user-memory';
|
||||
import { attributesCommon } from '@lobechat/observability-otel/node';
|
||||
import type {
|
||||
AiProviderRuntimeState,
|
||||
IdentityMemoryDetail,
|
||||
MemoryExtractionAgentCallTrace,
|
||||
MemoryExtractionTraceError,
|
||||
@@ -55,6 +56,7 @@ import type { ListUsersForMemoryExtractorCursor } from '@/database/models/user';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { UserMemoryModel } from '@/database/models/userMemory';
|
||||
import { UserMemorySourceBenchmarkLoCoMoModel } from '@/database/models/userMemory/sources/benchmarkLoCoMo';
|
||||
import { AiInfraRepos } from '@/database/repositories/aiInfra';
|
||||
import { getServerDB } from '@/database/server';
|
||||
import { getServerGlobalConfig } from '@/server/globalConfig';
|
||||
import {
|
||||
@@ -64,7 +66,7 @@ import {
|
||||
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
|
||||
import { S3 } from '@/server/modules/S3';
|
||||
import type { GlobalMemoryLayer } from '@/types/serverConfig';
|
||||
import type { UserKeyVaults } from '@/types/user/settings';
|
||||
import type { ProviderConfig } from '@/types/user/settings';
|
||||
import {
|
||||
LayersEnum,
|
||||
MemorySourceType,
|
||||
@@ -212,6 +214,11 @@ export interface TopicBatchWorkflowPayload extends MemoryExtractionPayloadInput
|
||||
userId: string;
|
||||
}
|
||||
|
||||
/**
 * Key vaults indexed by provider id. The value type is taken from the
 * `keyVaults` field of an `AiProviderRuntimeState` runtime config entry;
 * a provider may have no vault, in which case the value is `undefined`.
 */
type ProviderKeyVaultMap = {
  [providerId: string]: AiProviderRuntimeState['runtimeConfig'][string]['keyVaults'] | undefined;
};
|
||||
|
||||
export const buildWorkflowPayloadInput = (
|
||||
payload: MemoryExtractionNormalizedPayload,
|
||||
): MemoryExtractionPayloadInput => ({
|
||||
@@ -232,11 +239,9 @@ export const buildWorkflowPayloadInput = (
|
||||
userIds: payload.userIds,
|
||||
});
|
||||
|
||||
const normalizeProvider = (provider: string) => provider.toLowerCase() as keyof UserKeyVaults;
|
||||
|
||||
const extractCredentialsFromVault = (provider: string, keyVaults?: UserKeyVaults) => {
|
||||
const vault = keyVaults?.[normalizeProvider(provider)];
|
||||
const normalizeProvider = (provider: string) => provider.toLowerCase();
|
||||
|
||||
const extractCredentialsFromVault = (vault?: Record<string, unknown>) => {
|
||||
if (!vault || typeof vault !== 'object') return {};
|
||||
|
||||
const apiKey = 'apiKey' in vault && typeof vault.apiKey === 'string' ? vault.apiKey : undefined;
|
||||
@@ -275,11 +280,10 @@ const maskSecret = (value?: string) => {
|
||||
return `${value.slice(0, 6)}***${value.slice(-4)}`;
|
||||
};
|
||||
|
||||
const resolveRuntimeAgentConfig = (agent: MemoryAgentConfig, keyVaults?: UserKeyVaults) => {
|
||||
const resolveRuntimeAgentConfig = (agent: MemoryAgentConfig, keyVaults?: ProviderKeyVaultMap) => {
|
||||
const provider = agent.provider || 'openai';
|
||||
const { apiKey: userApiKey, baseURL: userBaseURL } = extractCredentialsFromVault(
|
||||
provider,
|
||||
keyVaults,
|
||||
keyVaults?.[normalizeProvider(provider)],
|
||||
);
|
||||
|
||||
// Only use the user baseURL if we are also using their API key; otherwise fall back entirely
|
||||
@@ -309,7 +313,7 @@ const debugRuntimeInit = (
|
||||
});
|
||||
};
|
||||
|
||||
const initRuntimeForAgent = async (agent: MemoryAgentConfig, keyVaults?: UserKeyVaults) => {
|
||||
const initRuntimeForAgent = async (agent: MemoryAgentConfig, keyVaults?: ProviderKeyVaultMap) => {
|
||||
const resolved = resolveRuntimeAgentConfig(agent, keyVaults);
|
||||
debugRuntimeInit(agent, resolved);
|
||||
|
||||
@@ -366,6 +370,13 @@ type MemoryExtractionConfig = ReturnType<typeof parseMemoryExtractionConfig>;
|
||||
type ServerConfig = Awaited<ReturnType<typeof getServerGlobalConfig>>;
|
||||
|
||||
export class MemoryExtractionExecutor {
|
||||
private readonly aiProviderConfig: Record<string, ProviderConfig>;
|
||||
private readonly embeddingPreferredModels?: string[];
|
||||
private readonly embeddingPreferredProviders?: string[];
|
||||
private readonly gatekeeperPreferredModels?: string[];
|
||||
private readonly gatekeeperPreferredProviders?: string[];
|
||||
private readonly layerPreferredModels?: string[];
|
||||
private readonly layerPreferredProviders?: string[];
|
||||
private readonly privateConfig: MemoryExtractionConfig;
|
||||
private readonly modelConfig: {
|
||||
embeddingsModel: string;
|
||||
@@ -380,6 +391,13 @@ export class MemoryExtractionExecutor {
|
||||
|
||||
private constructor(serverConfig: ServerConfig, privateConfig: MemoryExtractionConfig) {
|
||||
this.privateConfig = privateConfig;
|
||||
this.aiProviderConfig = (serverConfig.aiProvider || {}) as Record<string, ProviderConfig>;
|
||||
this.embeddingPreferredProviders = privateConfig.embeddingPreferredProviders;
|
||||
this.embeddingPreferredModels = privateConfig.embeddingPreferredModels;
|
||||
this.gatekeeperPreferredProviders = privateConfig.agentGateKeeperPreferredProviders;
|
||||
this.gatekeeperPreferredModels = privateConfig.agentGateKeeperPreferredModels;
|
||||
this.layerPreferredProviders = privateConfig.agentLayerExtractorPreferredProviders;
|
||||
this.layerPreferredModels = privateConfig.agentLayerExtractorPreferredModels;
|
||||
|
||||
const publicMemoryConfig = serverConfig.memory?.userMemory;
|
||||
|
||||
@@ -1014,8 +1032,11 @@ export class MemoryExtractionExecutor {
|
||||
};
|
||||
|
||||
const userModel = new UserModel(db, job.userId);
|
||||
const userState = await userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults);
|
||||
const keyVaults = userState.settings?.keyVaults as UserKeyVaults | undefined;
|
||||
const [userState, aiProviderRuntimeState] = await Promise.all([
|
||||
userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults),
|
||||
this.getAiProviderRuntimeState(job.userId),
|
||||
]);
|
||||
const keyVaults = this.resolveRuntimeKeyVaults(aiProviderRuntimeState);
|
||||
const language = userState.settings?.general?.responseLanguage;
|
||||
|
||||
const runtimes = await this.getRuntime(job.userId, keyVaults);
|
||||
@@ -1625,7 +1646,112 @@ export class MemoryExtractionExecutor {
|
||||
};
|
||||
}
|
||||
|
||||
private async getRuntime(userId: string, keyVaults?: UserKeyVaults): Promise<RuntimeBundle> {
|
||||
/**
 * Loads the AI provider runtime state for one user.
 *
 * Reads through an `AiInfraRepos` built from the executor's database handle,
 * the given user id, and the server-level provider config.
 *
 * @param userId - User whose provider runtime state is requested.
 * @returns The runtime state reported by the repository.
 */
private async getAiProviderRuntimeState(userId: string): Promise<AiProviderRuntimeState> {
  const repos = new AiInfraRepos(await this.db, userId, this.aiProviderConfig);

  // NOTE(review): `getUserKeyVaults` is presumably the key-vault decryptor — confirm.
  return repos.getAiProviderRuntimeState(KeyVaultsGateKeeper.getUserKeyVaults);
}
|
||||
|
||||
/**
 * Resolves which provider serves each configured memory-extraction model
 * (gatekeeper, embedding, and every layer extractor) and collects the key
 * vaults of those providers from the runtime state.
 *
 * For each model, candidate providers are tried in this order: the preferred
 * providers, then the agent's configured provider, then every provider that
 * has at least one enabled model. A candidate is accepted when it exposes the
 * configured model id, or any of the preferred models.
 *
 * @param runtimeState - Enabled models and per-provider runtime config.
 * @returns Key vaults keyed by normalized provider id. A resolved provider
 *   whose runtime config carries no key vaults is left out of the map.
 * @throws Error when no candidate provider exposes the configured model or a
 *   preferred model.
 */
private resolveRuntimeKeyVaults(runtimeState: AiProviderRuntimeState): ProviderKeyVaultMap {
  const normalizedRuntimeConfig = Object.fromEntries(
    Object.entries(runtimeState.runtimeConfig || {}).map(([providerId, config]) => [
      normalizeProvider(providerId),
      config,
    ]),
  );

  // Two indexes per provider: exact ids for the configured model, and lower-cased
  // ids for preferred models. Preferred models are lower-cased when parsed from the
  // environment, so comparing them with raw ids would miss any id with upper case.
  const providerModels: Record<string, Set<string>> = {};
  const providerModelsLowerCased: Record<string, Set<string>> = {};
  for (const model of runtimeState.enabledAiModels || []) {
    const providerId = normalizeProvider(model.providerId);
    providerModels[providerId] = providerModels[providerId] || new Set<string>();
    providerModelsLowerCased[providerId] =
      providerModelsLowerCased[providerId] || new Set<string>();
    providerModels[providerId].add(model.id);
    providerModelsLowerCased[providerId].add(model.id.toLowerCase());
  }

  const resolveProviderForModel = (
    modelId: string,
    fallbackProvider?: string,
    preferredProviders?: string[],
    preferredModels?: string[],
    label?: string,
  ) => {
    // The Set drops repeated provider ids while keeping the priority order.
    const providerOrder = Array.from(
      new Set(
        [
          ...(preferredProviders?.map(normalizeProvider) || []),
          fallbackProvider ? normalizeProvider(fallbackProvider) : undefined,
          ...Object.keys(providerModels),
        ].filter(Boolean) as string[],
      ),
    );

    const candidateModels = (preferredModels || []).map((preferredModel) =>
      preferredModel.toLowerCase(),
    );

    for (const providerId of providerOrder) {
      const models = providerModels[providerId];
      // Preferred or fallback providers without any enabled model cannot serve anything.
      if (!models) continue;

      if (models.has(modelId)) return providerId;

      const lowerCasedModels = providerModelsLowerCased[providerId];
      if (candidateModels.some((preferredModel) => lowerCasedModels.has(preferredModel))) {
        return providerId;
      }
    }

    throw new Error(
      `Unable to resolve provider for ${label || 'model'} "${modelId}". Check preferred providers/models configuration.`,
    );
  };

  const keyVaults: ProviderKeyVaultMap = {};

  // Copies the key vaults of a resolved provider into the result, when it has any.
  const collectKeyVaults = (providerId: string) => {
    const runtime = normalizedRuntimeConfig[providerId];
    if (runtime?.keyVaults) {
      keyVaults[providerId] = runtime.keyVaults;
    }
  };

  collectKeyVaults(
    resolveProviderForModel(
      this.modelConfig.gateModel,
      this.privateConfig.agentGateKeeper.provider,
      this.gatekeeperPreferredProviders,
      this.gatekeeperPreferredModels,
      'gatekeeper',
    ),
  );

  collectKeyVaults(
    resolveProviderForModel(
      this.modelConfig.embeddingsModel,
      this.privateConfig.embedding.provider,
      this.embeddingPreferredProviders,
      this.embeddingPreferredModels,
      'embedding',
    ),
  );

  for (const model of Object.values(this.modelConfig.layerModels)) {
    // Layers without a configured model need no provider.
    if (!model) continue;

    collectKeyVaults(
      resolveProviderForModel(
        model,
        this.privateConfig.agentLayerExtractor.provider,
        this.layerPreferredProviders,
        this.layerPreferredModels,
        'layer extractor',
      ),
    );
  }

  return keyVaults;
}
|
||||
|
||||
private async getRuntime(
|
||||
userId: string,
|
||||
keyVaults?: ProviderKeyVaultMap,
|
||||
): Promise<RuntimeBundle> {
|
||||
// TODO: implement a better cache eviction strategy
|
||||
// TODO: make cache size configurable
|
||||
if (this.runtimeCache.keys.length > 200) {
|
||||
@@ -1673,8 +1799,11 @@ export class MemoryExtractionExecutor {
|
||||
try {
|
||||
const db = await this.db;
|
||||
const userModel = new UserModel(db, params.userId);
|
||||
const userState = await userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults);
|
||||
const keyVaults = userState.settings?.keyVaults as UserKeyVaults | undefined;
|
||||
const [userState, aiProviderRuntimeState] = await Promise.all([
|
||||
userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults),
|
||||
this.getAiProviderRuntimeState(params.userId),
|
||||
]);
|
||||
const keyVaults = this.resolveRuntimeKeyVaults(aiProviderRuntimeState);
|
||||
const language = params.language || userState.settings?.general?.responseLanguage;
|
||||
|
||||
const runtimes = await this.getRuntime(params.userId, keyVaults);
|
||||
@@ -1849,7 +1978,10 @@ export class MemoryExtractionWorkflowService {
|
||||
return this.client;
|
||||
}
|
||||
|
||||
static triggerProcessUsers(payload: MemoryExtractionPayloadInput, options?: { extraHeaders?: Record<string, string> }) {
|
||||
static triggerProcessUsers(
|
||||
payload: MemoryExtractionPayloadInput,
|
||||
options?: { extraHeaders?: Record<string, string> },
|
||||
) {
|
||||
if (!payload.baseUrl) {
|
||||
throw new Error('Missing baseUrl for workflow trigger');
|
||||
}
|
||||
@@ -1858,7 +1990,10 @@ export class MemoryExtractionWorkflowService {
|
||||
return this.getClient().trigger({ body: payload, headers: options?.extraHeaders, url });
|
||||
}
|
||||
|
||||
static triggerProcessUserTopics(payload: UserTopicWorkflowPayload, options?: { extraHeaders?: Record<string, string> }) {
|
||||
static triggerProcessUserTopics(
|
||||
payload: UserTopicWorkflowPayload,
|
||||
options?: { extraHeaders?: Record<string, string> },
|
||||
) {
|
||||
if (!payload.baseUrl) {
|
||||
throw new Error('Missing baseUrl for workflow trigger');
|
||||
}
|
||||
@@ -1867,7 +2002,10 @@ export class MemoryExtractionWorkflowService {
|
||||
return this.getClient().trigger({ body: payload, headers: options?.extraHeaders, url });
|
||||
}
|
||||
|
||||
static triggerProcessTopics(payload: MemoryExtractionPayloadInput, options?: { extraHeaders?: Record<string, string> }) {
|
||||
static triggerProcessTopics(
|
||||
payload: MemoryExtractionPayloadInput,
|
||||
options?: { extraHeaders?: Record<string, string> },
|
||||
) {
|
||||
if (!payload.baseUrl) {
|
||||
throw new Error('Missing baseUrl for workflow trigger');
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user