mirror of
https://github.com/lobehub/lobehub.git
synced 2026-03-27 13:29:15 +07:00
♻️ refactor(memory-user-memory): simplify buildContext(...) (#11808)
This commit is contained in:
@@ -3,9 +3,9 @@ import { MemorySourceType } from '@lobechat/types';
|
||||
import { readFile } from 'node:fs/promises';
|
||||
import { isAbsolute, join } from 'node:path';
|
||||
|
||||
import { BenchmarkLocomoContextProvider, BenchmarkLocomoPart } from '../../../../src/providers';
|
||||
import type { IngestPayload } from '../../../../src/converters/locomo';
|
||||
import { activityPrompt } from '../../../../src/prompts';
|
||||
import { BenchmarkLocomoContextProvider, BenchmarkLocomoPart } from '../../../../src/providers';
|
||||
import type { ExtractorTemplateProps, MemoryExtractionJob } from '../../../../src/types';
|
||||
|
||||
export interface PromptVars extends ExtractorTemplateProps {
|
||||
@@ -44,7 +44,11 @@ const buildParts = (payload: IngestPayload, sessionId?: string): BenchmarkLocomo
|
||||
);
|
||||
};
|
||||
|
||||
const resolveSessionDate = (payload: IngestPayload, parts: BenchmarkLocomoPart[], sessionId?: string) => {
|
||||
const resolveSessionDate = (
|
||||
payload: IngestPayload,
|
||||
parts: BenchmarkLocomoPart[],
|
||||
sessionId?: string,
|
||||
) => {
|
||||
const sessionDate =
|
||||
payload.sessions.find((session) => session.sessionId === sessionId)?.timestamp ||
|
||||
payload.sessions[0]?.timestamp;
|
||||
@@ -67,7 +71,9 @@ export const buildLocomoActivityMessages = async (vars: PromptVars) => {
|
||||
|
||||
const parts = buildParts(payload, vars.sessionId);
|
||||
if (parts.length === 0) {
|
||||
throw new Error(`No matching parts found in ${payload.sampleId} for session ${vars.sessionId || 'all'}`);
|
||||
throw new Error(
|
||||
`No matching parts found in ${payload.sampleId} for session ${vars.sessionId || 'all'}`,
|
||||
);
|
||||
}
|
||||
const userId = vars.userId || `locomo-user-${payload.sampleId}`;
|
||||
const sourceId = payload.topicId || `sample_${payload.sampleId}`;
|
||||
@@ -86,7 +92,7 @@ export const buildLocomoActivityMessages = async (vars: PromptVars) => {
|
||||
userId,
|
||||
};
|
||||
|
||||
const { context } = await provider.buildContext(extractionJob);
|
||||
const { context } = await provider.buildContext(extractionJob.userId);
|
||||
|
||||
const rendered = renderPlaceholderTemplate(activityPrompt, {
|
||||
availableCategories: vars.availableCategories,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { MemorySourceType } from '@lobechat/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import locomoIngestPayloads from './tests/benchmark-locomo-converted.json';
|
||||
import { BenchmarkLocomoContextProvider } from './benchmarkLocomo';
|
||||
import locomoIngestPayloads from './tests/benchmark-locomo-converted.json';
|
||||
|
||||
describe('BenchmarkLocomoContextProvider', () => {
|
||||
it('should convert LoCoMo ingest payload into benchmark XML context', async () => {
|
||||
@@ -42,11 +41,7 @@ describe('BenchmarkLocomoContextProvider', () => {
|
||||
userId,
|
||||
});
|
||||
|
||||
const result = await provider.buildContext({
|
||||
source: MemorySourceType.BenchmarkLocomo,
|
||||
sourceId,
|
||||
userId,
|
||||
});
|
||||
const result = await provider.buildContext(userId);
|
||||
|
||||
const { context } = result;
|
||||
expect(result.sourceId).toBe(sourceId);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { toXml } from 'xast-util-to-xml';
|
||||
import type { Child } from 'xastscript';
|
||||
import { x } from 'xastscript';
|
||||
|
||||
import type { BuiltContext, MemoryContextProvider, MemoryExtractionJob } from '../types';
|
||||
import type { BuiltContext, MemoryContextProvider } from '../types';
|
||||
|
||||
export interface BenchmarkLocomoPart {
|
||||
content: string;
|
||||
@@ -21,9 +21,10 @@ export interface BenchmarkLocomoContextProviderOptions {
|
||||
userId: string;
|
||||
}
|
||||
|
||||
export class BenchmarkLocomoContextProvider
|
||||
implements MemoryContextProvider<Record<string, unknown>, Record<string, unknown>>
|
||||
{
|
||||
export class BenchmarkLocomoContextProvider implements MemoryContextProvider<
|
||||
Record<string, unknown>,
|
||||
Record<string, unknown>
|
||||
> {
|
||||
private readonly options: BenchmarkLocomoContextProviderOptions;
|
||||
|
||||
constructor(options: BenchmarkLocomoContextProviderOptions) {
|
||||
@@ -42,7 +43,7 @@ export class BenchmarkLocomoContextProvider
|
||||
return x('message', attributes, part.content, metadata ? `\n[metadata:${metadata}]` : '');
|
||||
}
|
||||
|
||||
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext<Record<string, unknown>>> {
|
||||
async buildContext(userId: string): Promise<BuiltContext<Record<string, unknown>>> {
|
||||
const messageChildren: Child[] = this.options.parts.map((part, index) =>
|
||||
this.buildMessageNode(part, index),
|
||||
);
|
||||
@@ -63,7 +64,7 @@ export class BenchmarkLocomoContextProvider
|
||||
context: toXml(root),
|
||||
metadata: {},
|
||||
sourceId: this.options.sourceId,
|
||||
userId: job.userId,
|
||||
userId: userId,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { MemorySourceType } from '@lobechat/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { LobeChatTopicContextProvider } from './chatTopic';
|
||||
|
||||
@@ -34,7 +34,7 @@ describe('LobeChatTopicContextProvider', () => {
|
||||
topicId: 'topic-1',
|
||||
});
|
||||
|
||||
const result = await provider.buildContext(job);
|
||||
const result = await provider.buildContext(job.userId);
|
||||
|
||||
expect(result.context).toContain(
|
||||
'<chat_topic created_at="2024-03-01T09:00:00.000Z" id="topic-1" message_count="2" last_message_at="2024-03-01T10:01:00.000Z" updated_at="2024-03-01T10:02:00.000Z">',
|
||||
|
||||
@@ -47,7 +47,7 @@ export class LobeChatTopicContextProvider implements MemoryContextProvider<
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext> {
|
||||
async buildContext(userId: string): Promise<BuiltContext> {
|
||||
const messageChildren: Child[] = [];
|
||||
|
||||
this.options.conversations.forEach((message, index) => {
|
||||
@@ -118,7 +118,7 @@ export class LobeChatTopicContextProvider implements MemoryContextProvider<
|
||||
context: topicContext,
|
||||
metadata: {},
|
||||
sourceId: this.options.topicId,
|
||||
userId: job.userId,
|
||||
userId: userId,
|
||||
} satisfies BuiltContext;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { LayersEnum, MemorySourceType } from '@lobechat/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
RetrievalUserMemoryContextProvider,
|
||||
RetrievalUserMemoryIdentitiesProvider,
|
||||
} from './existingUserMemory';
|
||||
import { LayersEnum, MemorySourceType } from '@lobechat/types';
|
||||
|
||||
const job = {
|
||||
source: MemorySourceType.ChatTopic,
|
||||
@@ -109,7 +109,7 @@ describe('RetrievalUserMemoryContextProvider', () => {
|
||||
},
|
||||
});
|
||||
|
||||
const result = await provider.buildContext(job);
|
||||
const result = await provider.buildContext(job.userId, job.sourceId);
|
||||
|
||||
expect(result.sourceId).toBe('topic-1');
|
||||
expect(result.userId).toBe('user-1');
|
||||
@@ -167,7 +167,7 @@ describe('RetrievalUserMemoryIdentitiesProvider', () => {
|
||||
],
|
||||
});
|
||||
|
||||
const result = await provider.buildContext(job);
|
||||
const result = await provider.buildContext(job.userId, job.sourceId);
|
||||
|
||||
expect(result.sourceId).toBe('topic-1');
|
||||
expect(result.userId).toBe('user-1');
|
||||
|
||||
@@ -10,7 +10,7 @@ import { toXml } from 'xast-util-to-xml';
|
||||
import type { Child } from 'xastscript';
|
||||
import { x } from 'xastscript';
|
||||
|
||||
import type { BuiltContext, MemoryContextProvider, MemoryExtractionJob } from '../types';
|
||||
import type { BuiltContext, MemoryContextProvider } from '../types';
|
||||
|
||||
interface RetrievedMemories {
|
||||
activities: UserMemoryActivityWithoutVectors[];
|
||||
@@ -69,14 +69,12 @@ export class RetrievalUserMemoryContextProvider implements MemoryContextProvider
|
||||
return parts.join(' | ');
|
||||
}
|
||||
|
||||
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext> {
|
||||
async buildContext(userId: string, sourceId: string): Promise<BuiltContext> {
|
||||
const activities = this.retrievedMemories.activities || [];
|
||||
const contexts = this.retrievedMemories.contexts || [];
|
||||
const experiences = this.retrievedMemories.experiences || [];
|
||||
const preferences = this.retrievedMemories.preferences || [];
|
||||
|
||||
|
||||
|
||||
const userMemoriesChildren: Child[] = [];
|
||||
|
||||
activities.forEach((activity) => {
|
||||
@@ -252,8 +250,8 @@ export class RetrievalUserMemoryContextProvider implements MemoryContextProvider
|
||||
return {
|
||||
context: memoryContext,
|
||||
metadata: {},
|
||||
sourceId: job.sourceId,
|
||||
userId: job.userId,
|
||||
sourceId: sourceId,
|
||||
userId: userId,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -267,7 +265,7 @@ export class RetrievalUserMemoryIdentitiesProvider implements MemoryContextProvi
|
||||
this.fetchedAt = options.fetchedAt;
|
||||
}
|
||||
|
||||
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext> {
|
||||
async buildContext(userId: string, sourceId: string): Promise<BuiltContext> {
|
||||
const identityChildren: Child[] = [];
|
||||
|
||||
this.retrievedIdentities.forEach((item) => {
|
||||
@@ -345,8 +343,8 @@ export class RetrievalUserMemoryIdentitiesProvider implements MemoryContextProvi
|
||||
return {
|
||||
context: identityContext,
|
||||
metadata: {},
|
||||
sourceId: job.sourceId,
|
||||
userId: job.userId,
|
||||
sourceId: sourceId,
|
||||
userId: userId,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ export interface MemoryContextProvider<
|
||||
P extends Record<string, unknown> = Record<string, unknown>,
|
||||
R extends Record<string, unknown> = Record<string, unknown>,
|
||||
> {
|
||||
buildContext(job: MemoryExtractionJob, options?: P): Promise<BuiltContext<R>>;
|
||||
buildContext(userId: string, sourceId: string, options?: P): Promise<BuiltContext<R>>;
|
||||
}
|
||||
|
||||
export interface MemoryResultRecorder<T = Record<string, unknown>> {
|
||||
|
||||
@@ -1186,7 +1186,7 @@ export class MemoryExtractionExecutor {
|
||||
topic: topic,
|
||||
topicId: topic.id,
|
||||
});
|
||||
const topicContext = await topicContextProvider.buildContext(extractionJob);
|
||||
const topicContext = await topicContextProvider.buildContext(extractionJob.userId);
|
||||
|
||||
resultRecorder = new LobeChatTopicResultRecorder({
|
||||
currentMetadata: topic.metadata || {},
|
||||
@@ -1208,8 +1208,10 @@ export class MemoryExtractionExecutor {
|
||||
const retrievedMemoryContextProvider = new RetrievalUserMemoryContextProvider({
|
||||
retrievedMemories,
|
||||
});
|
||||
const retrievalMemoryContext =
|
||||
await retrievedMemoryContextProvider.buildContext(extractionJob);
|
||||
const retrievalMemoryContext = await retrievedMemoryContextProvider.buildContext(
|
||||
extractionJob.userId,
|
||||
extractionJob.sourceId,
|
||||
);
|
||||
|
||||
const retrievedMemoryIdentities = await this.listUserMemoryIdentities(
|
||||
extractionJob,
|
||||
@@ -1220,7 +1222,10 @@ export class MemoryExtractionExecutor {
|
||||
retrievedIdentities: retrievedMemoryIdentities,
|
||||
});
|
||||
const retrievedIdentityContext =
|
||||
await retrievedMemoryIdentitiesContextProvider.buildContext(extractionJob);
|
||||
await retrievedMemoryIdentitiesContextProvider.buildContext(
|
||||
extractionJob.userId,
|
||||
extractionJob.sourceId,
|
||||
);
|
||||
const trimmedRetrievedContexts = [
|
||||
topicContext.context,
|
||||
retrievalMemoryContext.context,
|
||||
@@ -1994,7 +1999,7 @@ export class MemoryExtractionExecutor {
|
||||
userId: params.userId,
|
||||
};
|
||||
|
||||
const builtContext = await contextProvider.buildContext(extractionJob);
|
||||
const builtContext = await contextProvider.buildContext(extractionJob.userId);
|
||||
const extractorContextLimit = this.privateConfig.agentLayerExtractor.contextLimit;
|
||||
const trimmedContext = this.trimTextToTokenLimit(
|
||||
builtContext.context,
|
||||
|
||||
Reference in New Issue
Block a user