♻️ refactor(memory-user-memory): simplify buildContext(...) (#11808)

This commit is contained in:
Neko
2026-01-25 16:55:15 +08:00
committed by GitHub
parent 5929f7b196
commit d5a9913155
9 changed files with 44 additions and 39 deletions

View File

@@ -3,9 +3,9 @@ import { MemorySourceType } from '@lobechat/types';
import { readFile } from 'node:fs/promises';
import { isAbsolute, join } from 'node:path';
import { BenchmarkLocomoContextProvider, BenchmarkLocomoPart } from '../../../../src/providers';
import type { IngestPayload } from '../../../../src/converters/locomo';
import { activityPrompt } from '../../../../src/prompts';
import { BenchmarkLocomoContextProvider, BenchmarkLocomoPart } from '../../../../src/providers';
import type { ExtractorTemplateProps, MemoryExtractionJob } from '../../../../src/types';
export interface PromptVars extends ExtractorTemplateProps {
@@ -44,7 +44,11 @@ const buildParts = (payload: IngestPayload, sessionId?: string): BenchmarkLocomo
);
};
const resolveSessionDate = (payload: IngestPayload, parts: BenchmarkLocomoPart[], sessionId?: string) => {
const resolveSessionDate = (
payload: IngestPayload,
parts: BenchmarkLocomoPart[],
sessionId?: string,
) => {
const sessionDate =
payload.sessions.find((session) => session.sessionId === sessionId)?.timestamp ||
payload.sessions[0]?.timestamp;
@@ -67,7 +71,9 @@ export const buildLocomoActivityMessages = async (vars: PromptVars) => {
const parts = buildParts(payload, vars.sessionId);
if (parts.length === 0) {
throw new Error(`No matching parts found in ${payload.sampleId} for session ${vars.sessionId || 'all'}`);
throw new Error(
`No matching parts found in ${payload.sampleId} for session ${vars.sessionId || 'all'}`,
);
}
const userId = vars.userId || `locomo-user-${payload.sampleId}`;
const sourceId = payload.topicId || `sample_${payload.sampleId}`;
@@ -86,7 +92,7 @@ export const buildLocomoActivityMessages = async (vars: PromptVars) => {
userId,
};
const { context } = await provider.buildContext(extractionJob);
const { context } = await provider.buildContext(extractionJob.userId);
const rendered = renderPlaceholderTemplate(activityPrompt, {
availableCategories: vars.availableCategories,

View File

@@ -1,8 +1,7 @@
import { MemorySourceType } from '@lobechat/types';
import { describe, expect, it } from 'vitest';
import locomoIngestPayloads from './tests/benchmark-locomo-converted.json';
import { BenchmarkLocomoContextProvider } from './benchmarkLocomo';
import locomoIngestPayloads from './tests/benchmark-locomo-converted.json';
describe('BenchmarkLocomoContextProvider', () => {
it('should convert LoCoMo ingest payload into benchmark XML context', async () => {
@@ -42,11 +41,7 @@ describe('BenchmarkLocomoContextProvider', () => {
userId,
});
const result = await provider.buildContext({
source: MemorySourceType.BenchmarkLocomo,
sourceId,
userId,
});
const result = await provider.buildContext(userId);
const { context } = result;
expect(result.sourceId).toBe(sourceId);

View File

@@ -3,7 +3,7 @@ import { toXml } from 'xast-util-to-xml';
import type { Child } from 'xastscript';
import { x } from 'xastscript';
import type { BuiltContext, MemoryContextProvider, MemoryExtractionJob } from '../types';
import type { BuiltContext, MemoryContextProvider } from '../types';
export interface BenchmarkLocomoPart {
content: string;
@@ -21,9 +21,10 @@ export interface BenchmarkLocomoContextProviderOptions {
userId: string;
}
export class BenchmarkLocomoContextProvider
implements MemoryContextProvider<Record<string, unknown>, Record<string, unknown>>
{
export class BenchmarkLocomoContextProvider implements MemoryContextProvider<
Record<string, unknown>,
Record<string, unknown>
> {
private readonly options: BenchmarkLocomoContextProviderOptions;
constructor(options: BenchmarkLocomoContextProviderOptions) {
@@ -42,7 +43,7 @@ export class BenchmarkLocomoContextProvider
return x('message', attributes, part.content, metadata ? `\n[metadata:${metadata}]` : '');
}
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext<Record<string, unknown>>> {
async buildContext(userId: string): Promise<BuiltContext<Record<string, unknown>>> {
const messageChildren: Child[] = this.options.parts.map((part, index) =>
this.buildMessageNode(part, index),
);
@@ -63,7 +64,7 @@ export class BenchmarkLocomoContextProvider
context: toXml(root),
metadata: {},
sourceId: this.options.sourceId,
userId: job.userId,
userId: userId,
};
}
}

View File

@@ -1,5 +1,5 @@
import { describe, expect, it } from 'vitest';
import { MemorySourceType } from '@lobechat/types';
import { describe, expect, it } from 'vitest';
import { LobeChatTopicContextProvider } from './chatTopic';
@@ -34,7 +34,7 @@ describe('LobeChatTopicContextProvider', () => {
topicId: 'topic-1',
});
const result = await provider.buildContext(job);
const result = await provider.buildContext(job.userId);
expect(result.context).toContain(
'<chat_topic created_at="2024-03-01T09:00:00.000Z" id="topic-1" message_count="2" last_message_at="2024-03-01T10:01:00.000Z" updated_at="2024-03-01T10:02:00.000Z">',

View File

@@ -47,7 +47,7 @@ export class LobeChatTopicContextProvider implements MemoryContextProvider<
this.options = options;
}
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext> {
async buildContext(userId: string): Promise<BuiltContext> {
const messageChildren: Child[] = [];
this.options.conversations.forEach((message, index) => {
@@ -118,7 +118,7 @@ export class LobeChatTopicContextProvider implements MemoryContextProvider<
context: topicContext,
metadata: {},
sourceId: this.options.topicId,
userId: job.userId,
userId: userId,
} satisfies BuiltContext;
}
}

View File

@@ -1,10 +1,10 @@
import { LayersEnum, MemorySourceType } from '@lobechat/types';
import { describe, expect, it } from 'vitest';
import {
RetrievalUserMemoryContextProvider,
RetrievalUserMemoryIdentitiesProvider,
} from './existingUserMemory';
import { LayersEnum, MemorySourceType } from '@lobechat/types';
const job = {
source: MemorySourceType.ChatTopic,
@@ -109,7 +109,7 @@ describe('RetrievalUserMemoryContextProvider', () => {
},
});
const result = await provider.buildContext(job);
const result = await provider.buildContext(job.userId, job.sourceId);
expect(result.sourceId).toBe('topic-1');
expect(result.userId).toBe('user-1');
@@ -167,7 +167,7 @@ describe('RetrievalUserMemoryIdentitiesProvider', () => {
],
});
const result = await provider.buildContext(job);
const result = await provider.buildContext(job.userId, job.sourceId);
expect(result.sourceId).toBe('topic-1');
expect(result.userId).toBe('user-1');

View File

@@ -10,7 +10,7 @@ import { toXml } from 'xast-util-to-xml';
import type { Child } from 'xastscript';
import { x } from 'xastscript';
import type { BuiltContext, MemoryContextProvider, MemoryExtractionJob } from '../types';
import type { BuiltContext, MemoryContextProvider } from '../types';
interface RetrievedMemories {
activities: UserMemoryActivityWithoutVectors[];
@@ -69,14 +69,12 @@ export class RetrievalUserMemoryContextProvider implements MemoryContextProvider
return parts.join(' | ');
}
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext> {
async buildContext(userId: string, sourceId: string): Promise<BuiltContext> {
const activities = this.retrievedMemories.activities || [];
const contexts = this.retrievedMemories.contexts || [];
const experiences = this.retrievedMemories.experiences || [];
const preferences = this.retrievedMemories.preferences || [];
const userMemoriesChildren: Child[] = [];
activities.forEach((activity) => {
@@ -252,8 +250,8 @@ export class RetrievalUserMemoryContextProvider implements MemoryContextProvider
return {
context: memoryContext,
metadata: {},
sourceId: job.sourceId,
userId: job.userId,
sourceId: sourceId,
userId: userId,
};
}
}
@@ -267,7 +265,7 @@ export class RetrievalUserMemoryIdentitiesProvider implements MemoryContextProvi
this.fetchedAt = options.fetchedAt;
}
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext> {
async buildContext(userId: string, sourceId: string): Promise<BuiltContext> {
const identityChildren: Child[] = [];
this.retrievedIdentities.forEach((item) => {
@@ -345,8 +343,8 @@ export class RetrievalUserMemoryIdentitiesProvider implements MemoryContextProvi
return {
context: identityContext,
metadata: {},
sourceId: job.sourceId,
userId: job.userId,
sourceId: sourceId,
userId: userId,
};
}
}

View File

@@ -112,7 +112,7 @@ export interface MemoryContextProvider<
P extends Record<string, unknown> = Record<string, unknown>,
R extends Record<string, unknown> = Record<string, unknown>,
> {
buildContext(job: MemoryExtractionJob, options?: P): Promise<BuiltContext<R>>;
buildContext(userId: string, sourceId: string, options?: P): Promise<BuiltContext<R>>;
}
export interface MemoryResultRecorder<T = Record<string, unknown>> {

View File

@@ -1186,7 +1186,7 @@ export class MemoryExtractionExecutor {
topic: topic,
topicId: topic.id,
});
const topicContext = await topicContextProvider.buildContext(extractionJob);
const topicContext = await topicContextProvider.buildContext(extractionJob.userId);
resultRecorder = new LobeChatTopicResultRecorder({
currentMetadata: topic.metadata || {},
@@ -1208,8 +1208,10 @@ export class MemoryExtractionExecutor {
const retrievedMemoryContextProvider = new RetrievalUserMemoryContextProvider({
retrievedMemories,
});
const retrievalMemoryContext =
await retrievedMemoryContextProvider.buildContext(extractionJob);
const retrievalMemoryContext = await retrievedMemoryContextProvider.buildContext(
extractionJob.userId,
extractionJob.sourceId,
);
const retrievedMemoryIdentities = await this.listUserMemoryIdentities(
extractionJob,
@@ -1220,7 +1222,10 @@ export class MemoryExtractionExecutor {
retrievedIdentities: retrievedMemoryIdentities,
});
const retrievedIdentityContext =
await retrievedMemoryIdentitiesContextProvider.buildContext(extractionJob);
await retrievedMemoryIdentitiesContextProvider.buildContext(
extractionJob.userId,
extractionJob.sourceId,
);
const trimmedRetrievedContexts = [
topicContext.context,
retrievalMemoryContext.context,
@@ -1994,7 +1999,7 @@ export class MemoryExtractionExecutor {
userId: params.userId,
};
const builtContext = await contextProvider.buildContext(extractionJob);
const builtContext = await contextProvider.buildContext(extractionJob.userId);
const extractorContextLimit = this.privateConfig.agentLayerExtractor.contextLimit;
const trimmedContext = this.trimTextToTokenLimit(
builtContext.context,