Mirror of https://github.com/lobehub/lobehub.git (synced 2026-03-26 13:19:34 +07:00)
✨ feat: support server context compression (#12976)
* ♻️ refactor: add eval-only server context compression
* ♻️ refactor: align eval compression with runtime step flow
* ♻️ refactor: trim redundant call_llm diff
* ✨ add mid-run context compression step
* 📝 document post compression helper
* 🐛 revert unnecessary agent runtime service diff
* ♻️ refactor: clean up context compression follow-up logic
* ♻️ refactor: move compression gate before call llm
* ♻️ refactor: make call llm compression gate explicit
* ♻️ refactor: restore agent-side compression checks
* ♻️ refactor: rename agent llm continuation helper
* ♻️ refactor: inline agent compression helper
* ♻️ refactor: preserve trailing user message during compression
* 📝 docs: clarify toLLMCall refactor direction
* ✅ test: add coverage for context compression flow
* ⏪ reset: unstash
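The gist of the change: before every `call_llm` the agent now estimates the context size and, when it exceeds the configured window, emits a `compress_context` instruction that the server executor resolves by summarizing older messages into a compressed group. A minimal sketch of how a caller might opt in, based on the `compressionConfig` shape exercised in the tests below (the import path and the `maxWindowToken` value are assumptions for illustration):

```ts
// Import path assumed for illustration; compression is on by default
// (`enabled ?? true`) and gated by `maxWindowToken`.
import { GeneralChatAgent } from '@lobechat/agent-runtime';

const agent = new GeneralChatAgent({
  agentConfig: { maxSteps: 100 },
  compressionConfig: {
    enabled: true,
    // illustrative threshold; tune to the target model's context window
    maxWindowToken: 32_000,
  },
  modelRuntimeConfig: { model: 'gpt-4o-mini', provider: 'openai' },
  operationId: 'example-session',
});
```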
@@ -310,6 +310,36 @@ export class GeneralChatAgent implements Agent {
     return undefined;
   }
 
+  /**
+   * Proceed to the next LLM call, inserting compression first when needed.
+   */
+  private toLLMCall(payload: GeneralAgentCallLLMInstructionPayload): AgentInstruction {
+    const compressionEnabled = this.config.compressionConfig?.enabled ?? true;
+
+    if (compressionEnabled) {
+      const messages = payload.messages;
+      const compressionCheck = shouldCompress(messages, {
+        maxWindowToken: this.config.compressionConfig?.maxWindowToken,
+      });
+
+      if (compressionCheck.needsCompression) {
+        return {
+          payload: {
+            currentTokenCount: compressionCheck.currentTokenCount,
+            existingSummary: this.findExistingSummary(messages),
+            messages,
+          },
+          type: 'compress_context',
+        };
+      }
+    }
+
+    return {
+      payload,
+      type: 'call_llm',
+    };
+  }
+
   /**
    * Handle abort scenario - unified abort handling logic
    */
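For orientation, `shouldCompress` is consulted with the pending messages and the configured window, and the gate only reroutes to `compress_context` when the estimate exceeds that window. A rough sketch of the decision the method relies on; the result fields (`needsCompression`, `currentTokenCount`) match the usage above, while the token estimator itself is a placeholder assumption, not the real context-engine implementation:

```ts
// Hypothetical stand-in for the real estimator; only the result shape is taken from the diff.
const estimateTokens = (messages: { content: string }[]) =>
  messages.reduce((sum, m) => sum + Math.ceil(m.content.length / 4), 0);

const shouldCompressSketch = (
  messages: { content: string }[],
  options: { maxWindowToken?: number },
) => {
  const currentTokenCount = estimateTokens(messages);
  const limit = options.maxWindowToken ?? Number.POSITIVE_INFINITY;
  return { currentTokenCount, needsCompression: currentTokenCount > limit };
};
```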
@@ -517,16 +547,13 @@ export class GeneralChatAgent implements Agent {
         }
 
         // No pending tools, continue to call LLM with tool results
-        return {
-          payload: {
-            messages: state.messages,
-            model: this.config.modelRuntimeConfig?.model,
-            parentMessageId,
-            provider: this.config.modelRuntimeConfig?.provider,
-            tools: state.tools,
-          } as GeneralAgentCallLLMInstructionPayload,
-          type: 'call_llm',
-        };
+        return this.toLLMCall({
+          messages: state.messages,
+          model: this.config.modelRuntimeConfig?.model,
+          parentMessageId,
+          provider: this.config.modelRuntimeConfig?.provider,
+          tools: state.tools,
+        } as GeneralAgentCallLLMInstructionPayload);
       }
 
       case 'tools_batch_result': {
@@ -550,16 +577,13 @@ export class GeneralChatAgent implements Agent {
         }
 
         // No pending tools, continue to call LLM with tool results
-        return {
-          payload: {
-            messages: state.messages,
-            model: this.config.modelRuntimeConfig?.model,
-            parentMessageId,
-            provider: this.config.modelRuntimeConfig?.provider,
-            tools: state.tools,
-          } as GeneralAgentCallLLMInstructionPayload,
-          type: 'call_llm',
-        };
+        return this.toLLMCall({
+          messages: state.messages,
+          model: this.config.modelRuntimeConfig?.model,
+          parentMessageId,
+          provider: this.config.modelRuntimeConfig?.provider,
+          tools: state.tools,
+        } as GeneralAgentCallLLMInstructionPayload);
       }
 
       case 'task_result': {
@@ -567,16 +591,13 @@ export class GeneralChatAgent implements Agent {
         const { parentMessageId } = context.payload as TaskResultPayload;
 
         // Continue to call LLM with updated messages (task message is already in state)
-        return {
-          payload: {
-            messages: state.messages,
-            model: this.config.modelRuntimeConfig?.model,
-            parentMessageId,
-            provider: this.config.modelRuntimeConfig?.provider,
-            tools: state.tools,
-          } as GeneralAgentCallLLMInstructionPayload,
-          type: 'call_llm',
-        };
+        return this.toLLMCall({
+          messages: state.messages,
+          model: this.config.modelRuntimeConfig?.model,
+          parentMessageId,
+          provider: this.config.modelRuntimeConfig?.provider,
+          tools: state.tools,
+        } as GeneralAgentCallLLMInstructionPayload);
       }
 
       case 'tasks_batch_result': {
@@ -596,16 +617,13 @@ export class GeneralChatAgent implements Agent {
         ];
 
         // Continue to call LLM with updated messages (task messages are already in state)
-        return {
-          payload: {
-            messages: messagesWithPrompt,
-            model: this.config.modelRuntimeConfig?.model,
-            parentMessageId,
-            provider: this.config.modelRuntimeConfig?.provider,
-            tools: state.tools,
-          } as GeneralAgentCallLLMInstructionPayload,
-          type: 'call_llm',
-        };
+        return this.toLLMCall({
+          messages: messagesWithPrompt,
+          model: this.config.modelRuntimeConfig?.model,
+          parentMessageId,
+          provider: this.config.modelRuntimeConfig?.provider,
+          tools: state.tools,
+        } as GeneralAgentCallLLMInstructionPayload);
       }
 
       case 'compression_result': {
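The body of the `compression_result` case sits outside this hunk, but the new agent test further down pins its contract: continue to `call_llm` with the compressed messages and force a fresh assistant message. A sketch consistent with those expectations, not the literal implementation (every name outside the payload fields is an assumption):

```ts
// Hypothetical reconstruction from the compression_result test expectations below.
const resumeAfterCompression = (
  payload: { compressedMessages: unknown[]; parentMessageId?: string },
  state: { tools?: unknown[] },
  modelRuntimeConfig?: { model?: string; provider?: string },
) => ({
  payload: {
    createAssistantMessage: true,
    messages: payload.compressedMessages,
    model: modelRuntimeConfig?.model,
    parentMessageId: payload.parentMessageId,
    provider: modelRuntimeConfig?.provider,
    tools: state.tools,
  },
  type: 'call_llm' as const,
});
```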
@@ -53,6 +53,26 @@ describe('GeneralChatAgent', () => {
     },
   });
 
+  const createCompressionAgent = () =>
+    new GeneralChatAgent({
+      agentConfig: { maxSteps: 100 },
+      compressionConfig: {
+        enabled: true,
+        maxWindowToken: 1,
+      },
+      operationId: 'test-session',
+      modelRuntimeConfig: mockModelRuntimeConfig,
+    });
+
+  const expectCompressionInstruction = (messages: AgentState['messages']) => ({
+    type: 'compress_context',
+    payload: {
+      currentTokenCount: expect.any(Number),
+      existingSummary: undefined,
+      messages,
+    },
+  });
+
   describe('init and user_input phase', () => {
     it('should return call_llm instruction for init phase', async () => {
       const agent = new GeneralChatAgent({
@@ -612,6 +632,26 @@ describe('GeneralChatAgent', () => {
       });
     });
 
+    it('should return compress_context before continuing to LLM when tool results exceed window', async () => {
+      const agent = createCompressionAgent();
+
+      const state = createMockState({
+        messages: [
+          { role: 'user', content: 'Hello' },
+          { role: 'assistant', content: '' },
+          { role: 'tool', content: 'Result', tool_call_id: 'call-1' },
+        ] as any,
+      });
+
+      const context = createMockContext('tool_result', {
+        parentMessageId: 'tool-msg-1',
+      });
+
+      const result = await agent.runner(context, state);
+
+      expect(result).toEqual(expectCompressionInstruction(state.messages));
+    });
+
     it('should return request_human_approve when there are pending tools', async () => {
       const agent = new GeneralChatAgent({
         agentConfig: { maxSteps: 100 },
@@ -736,6 +776,27 @@ describe('GeneralChatAgent', () => {
         skipCreateToolMessage: true,
       });
     });
+
+    it('should return compress_context before continuing to LLM when batch tool results exceed window', async () => {
+      const agent = createCompressionAgent();
+
+      const state = createMockState({
+        messages: [
+          { role: 'user', content: 'Hello' },
+          { role: 'assistant', content: '' },
+          { role: 'tool', content: 'Result 1', tool_call_id: 'call-1' },
+          { role: 'tool', content: 'Result 2', tool_call_id: 'call-2' },
+        ] as any,
+      });
+
+      const context = createMockContext('tools_batch_result', {
+        parentMessageId: 'tool-msg-2',
+      });
+
+      const result = await agent.runner(context, state);
+
+      expect(result).toEqual(expectCompressionInstruction(state.messages));
+    });
   });
 
   describe('error phase', () => {
@@ -1181,6 +1242,26 @@ describe('GeneralChatAgent', () => {
         },
       });
     });
+
+    it('should return compress_context before continuing to LLM when task results exceed window', async () => {
+      const agent = createCompressionAgent();
+
+      const state = createMockState({
+        messages: [
+          { role: 'user', content: 'Execute task' },
+          { role: 'assistant', content: '' },
+          { role: 'task', content: 'Task result', metadata: { instruction: 'Do task' } },
+        ] as any,
+      });
+
+      const context = createMockContext('task_result', {
+        parentMessageId: 'task-parent-msg',
+      });
+
+      const result = await agent.runner(context, state);
+
+      expect(result).toEqual(expectCompressionInstruction(state.messages));
+    });
   });
 
   describe('tasks_batch_result phase (multiple tasks)', () => {
@@ -1278,6 +1359,75 @@ describe('GeneralChatAgent', () => {
         },
       });
     });
+
+    it('should return compress_context and preserve the follow-up prompt when tasks exceed window', async () => {
+      const agent = createCompressionAgent();
+
+      const state = createMockState({
+        messages: [
+          { role: 'user', content: 'Execute tasks' },
+          { role: 'assistant', content: '' },
+          { role: 'task', content: 'Task 1 result', metadata: { instruction: 'Do task 1' } },
+          { role: 'task', content: 'Task 2 result', metadata: { instruction: 'Do task 2' } },
+        ] as any,
+      });
+
+      const context = createMockContext('tasks_batch_result', {
+        parentMessageId: 'task-parent-msg',
+      });
+
+      const result = await agent.runner(context, state);
+
+      expect(result).toEqual(
+        expectCompressionInstruction([
+          ...state.messages,
+          {
+            content:
+              'All tasks above have been completed. Please summarize the results or continue with your response following user query language.',
+            role: 'user',
+          },
+        ]),
+      );
+    });
   });
+
+  describe('compression_result phase', () => {
+    it('should return call_llm with compressed messages and force a new assistant message', async () => {
+      const agent = new GeneralChatAgent({
+        agentConfig: { maxSteps: 100 },
+        operationId: 'test-session',
+        modelRuntimeConfig: mockModelRuntimeConfig,
+      });
+
+      const compressedMessages = [
+        { content: 'Compressed summary', id: 'group-1', role: 'compressedGroup' },
+        { content: 'Latest user follow-up', role: 'user' },
+      ] as any;
+
+      const state = createMockState({
+        tools: [{ name: 'search' }] as any,
+      });
+
+      const context = createMockContext('compression_result', {
+        compressedMessages,
+        parentMessageId: 'assistant-msg-after-compression',
+        skipped: false,
+      });
+
+      const result = await agent.runner(context, state);
+
+      expect(result).toEqual({
+        type: 'call_llm',
+        payload: {
+          createAssistantMessage: true,
+          messages: compressedMessages,
+          model: 'gpt-4o-mini',
+          parentMessageId: 'assistant-msg-after-compression',
+          provider: 'openai',
+          tools: state.tools,
+        },
+      });
+    });
+  });
 
   describe('unknown phase', () => {
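The server side of the feature follows. For orientation, the `compression_result` payload that the executor hands back to the agent (and that the test above asserts against) has roughly this shape; the field names are taken from the diff, while the exact types and optionality are assumptions:

```ts
// Sketch of GeneralAgentCompressionResultPayload as it is used in this diff.
interface CompressionResultPayloadSketch {
  /** messages the next call_llm should run with (summary group plus preserved tail) */
  compressedMessages: unknown[];
  /** id of the compression message group, or '' when compression was skipped */
  groupId: string;
  /** assistant message to attach the follow-up LLM call to, if one exists */
  parentMessageId?: string;
  /** true when compression was bypassed (missing topic/user, nothing to compress, or an error) */
  skipped?: boolean;
}
```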
@@ -1,8 +1,10 @@
 import {
   type AgentEvent,
   type AgentInstruction,
+  type AgentInstructionCompressContext,
   type CallLLMPayload,
   type GeneralAgentCallLLMResultPayload,
+  type GeneralAgentCompressionResultPayload,
   type InstructionExecutor,
   UsageCounter,
 } from '@lobechat/agent-runtime';
@@ -17,6 +19,7 @@ import {
 } from '@lobechat/context-engine';
 import { parse } from '@lobechat/conversation-flow';
 import { consumeStreamUntilDone } from '@lobechat/model-runtime';
+import { chainCompressContext } from '@lobechat/prompts';
 import { type ChatToolPayload, type MessageToolCall, type UIChatMessage } from '@lobechat/types';
 import { serializePartsForStorage } from '@lobechat/utils';
 import debug from 'debug';
@@ -26,6 +29,7 @@ import { type LobeChatDatabase } from '@/database/type';
 import { serverMessagesEngine } from '@/server/modules/Mecha/ContextEngineering';
 import { type EvalContext } from '@/server/modules/Mecha/ContextEngineering/types';
 import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
+import { MessageService } from '@/server/services/message';
 import { type ToolExecutionService } from '@/server/services/toolExecution';
 
 import { type IStreamEventManager } from './types';
@@ -590,6 +594,258 @@ export const createRuntimeExecutors = (
       throw error;
     }
   },
+
+  compress_context: async (instruction, state) => {
+    const { payload } = instruction as AgentInstructionCompressContext;
+    const { messages, currentTokenCount } = payload;
+    const { operationId, stepIndex } = ctx;
+    const operationLogId = `${operationId}:${stepIndex}`;
+    const stagePrefix = `[${operationLogId}][compress_context]`;
+    const events: AgentEvent[] = [];
+    const newState = structuredClone(state);
+    const topicId = state.metadata?.topicId;
+    const lastMessage = messages.at(-1);
+    const preservedMessages =
+      messages.length > 1 && lastMessage?.role === 'user' ? [lastMessage] : [];
+    const preservedMessageIds = new Set(
+      preservedMessages.map((message) => message.id).filter((id): id is string => Boolean(id)),
+    );
+    const messagesToCompress = preservedMessages.length > 0 ? messages.slice(0, -1) : messages;
+    const compressedMessagesFallback = [...messagesToCompress, ...preservedMessages];
+
+    if (!topicId || !ctx.userId) {
+      return {
+        events,
+        newState,
+        nextContext: {
+          payload: {
+            compressedMessages: compressedMessagesFallback,
+            groupId: '',
+            parentMessageId: undefined,
+            skipped: true,
+          } as GeneralAgentCompressionResultPayload,
+          phase: 'compression_result',
+          session: {
+            messageCount: newState.messages.length,
+            sessionId: operationId,
+            status: 'running',
+            stepCount: state.stepCount + 1,
+          },
+        },
+      };
+    }
+
+    try {
+      const dbMessages = await ctx.messageModel.query({
+        agentId: state.metadata?.agentId,
+        threadId: state.metadata?.threadId,
+        topicId,
+      });
+
+      const messageIds = dbMessages
+        .filter(
+          (message) =>
+            message.role !== 'compressedGroup' &&
+            Boolean(message.id) &&
+            !preservedMessageIds.has(message.id),
+        )
+        .map((message) => message.id);
+
+      if (messageIds.length === 0 || messagesToCompress.length === 0) {
+        return {
+          events,
+          newState,
+          nextContext: {
+            payload: {
+              compressedMessages: compressedMessagesFallback,
+              groupId: '',
+              parentMessageId: undefined,
+              skipped: true,
+            } as GeneralAgentCompressionResultPayload,
+            phase: 'compression_result',
+            session: {
+              messageCount: newState.messages.length,
+              sessionId: operationId,
+              status: 'running',
+              stepCount: state.stepCount + 1,
+            },
+          },
+        };
+      }
+
+      const latestAssistantMessage = dbMessages.findLast((message) => message.role === 'assistant');
+      const messageService = new MessageService(ctx.serverDB, ctx.userId);
+      const compressionResult = await messageService.createCompressionGroup(topicId, messageIds, {
+        agentId: state.metadata?.agentId,
+        threadId: state.metadata?.threadId,
+        topicId,
+      });
+
+      const compressionModel =
+        newState.modelRuntimeConfig?.compressionModel || newState.modelRuntimeConfig;
+
+      if (!compressionModel?.model || !compressionModel?.provider) {
+        return {
+          events,
+          newState,
+          nextContext: {
+            payload: {
+              compressedMessages: compressedMessagesFallback,
+              groupId: '',
+              parentMessageId: latestAssistantMessage?.id,
+              skipped: true,
+            } as GeneralAgentCompressionResultPayload,
+            phase: 'compression_result',
+            session: {
+              messageCount: newState.messages.length,
+              sessionId: operationId,
+              status: 'running',
+              stepCount: state.stepCount + 1,
+            },
+          },
+        };
+      }
+
+      const compressionPayload = chainCompressContext(compressionResult.messagesToSummarize);
+      const compressionRuntime = await initModelRuntimeFromDB(
+        ctx.serverDB,
+        ctx.userId,
+        compressionModel.provider,
+      );
+
+      let summaryContent = '';
+      let summaryUsage: any;
+      let summaryError: any;
+
+      const compressionResponse = await compressionRuntime.chat(
+        {
+          messages: compressionPayload.messages!,
+          model: compressionModel.model,
+          stream: true,
+        },
+        {
+          callback: {
+            onCompletion: async (data) => {
+              if (data.usage) summaryUsage = data.usage;
+            },
+            onError: async (errorData) => {
+              summaryError = errorData;
+            },
+            onText: async (text) => {
+              summaryContent += text;
+            },
+          },
+          user: ctx.userId,
+        },
+      );
+
+      await consumeStreamUntilDone(compressionResponse);
+
+      if (summaryError) {
+        throw new Error(
+          typeof summaryError.message === 'string'
+            ? summaryError.message
+            : JSON.stringify(summaryError),
+        );
+      }
+
+      const finalCompression = await messageService.finalizeCompression(
+        compressionResult.messageGroupId,
+        summaryContent,
+        {
+          agentId: state.metadata?.agentId,
+          threadId: state.metadata?.threadId,
+          topicId,
+        },
+      );
+
+      const compressedMessagesBase =
+        finalCompression.messages || compressionResult.messagesToSummarize;
+      const compressedMessages = [...compressedMessagesBase];
+
+      for (const preservedMessage of preservedMessages) {
+        if (
+          !compressedMessages.some(
+            (message) =>
+              message === preservedMessage ||
+              (Boolean(message.id) &&
+                Boolean(preservedMessage.id) &&
+                message.id === preservedMessage.id),
+          )
+        ) {
+          compressedMessages.push(preservedMessage);
+        }
+      }
+
+      newState.messages = compressedMessages;
+
+      if (summaryUsage) {
+        const { usage, cost } = UsageCounter.accumulateLLM({
+          cost: newState.cost,
+          model: compressionModel.model,
+          modelUsage: summaryUsage,
+          provider: compressionModel.provider,
+          usage: newState.usage,
+        });
+
+        newState.usage = usage;
+        if (cost) newState.cost = cost;
+      }
+
+      events.push({
+        groupId: compressionResult.messageGroupId,
+        parentMessageId: latestAssistantMessage?.id,
+        type: 'compression_complete',
+      });
+
+      return {
+        events,
+        newState,
+        nextContext: {
+          payload: {
+            compressedMessages,
+            groupId: compressionResult.messageGroupId,
+            parentMessageId: latestAssistantMessage?.id,
+          } as GeneralAgentCompressionResultPayload,
+          phase: 'compression_result',
+          session: {
+            messageCount: compressedMessages.length,
+            sessionId: operationId,
+            status: 'running',
+            stepCount: state.stepCount + 1,
+          },
+        },
+      };
+    } catch (error) {
+      log(
+        `${stagePrefix} Compression failed. originalTokens=%d error=%O`,
+        currentTokenCount,
+        error,
+      );
+
+      events.push({ error, type: 'compression_error' });
+
+      return {
+        events,
+        newState,
+        nextContext: {
+          payload: {
+            compressedMessages: compressedMessagesFallback,
+            groupId: '',
+            parentMessageId: undefined,
+            skipped: true,
+          } as GeneralAgentCompressionResultPayload,
+          phase: 'compression_result',
+          session: {
+            messageCount: newState.messages.length,
+            sessionId: operationId,
+            status: 'running',
+            stepCount: state.stepCount + 1,
+          },
+        },
+      };
+    }
+  },
   /**
    * Tool execution
    */
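The executor deliberately keeps a trailing user message out of the summary so the model still sees the exact instruction it is supposed to act on next. Pulled out of the code above into a standalone sketch (simplified message type, helper name invented for illustration):

```ts
interface Msg {
  content: string;
  id?: string;
  role: string;
}

// Keep the last message out of compression when it is a user turn and not the only
// message; everything before it becomes the candidate set for summarization.
const splitForCompression = (messages: Msg[]) => {
  const lastMessage = messages.at(-1);
  const preservedMessages =
    messages.length > 1 && lastMessage?.role === 'user' ? [lastMessage] : [];
  const messagesToCompress = preservedMessages.length > 0 ? messages.slice(0, -1) : messages;

  return { messagesToCompress, preservedMessages };
};
```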
@@ -7,6 +7,9 @@ import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
 
 import { createRuntimeExecutors, type RuntimeExecutorContext } from '../RuntimeExecutors';
 
+const mockCreateCompressionGroup = vi.fn();
+const mockFinalizeCompression = vi.fn();
+
 // Mock dependencies
 vi.mock('@/server/modules/ModelRuntime', () => ({
   initModelRuntimeFromDB: vi.fn().mockResolvedValue({
@@ -14,6 +17,13 @@ vi.mock('@/server/modules/ModelRuntime', () => ({
   }),
 }));
 
+vi.mock('@/server/services/message', () => ({
+  MessageService: vi.fn().mockImplementation(() => ({
+    createCompressionGroup: mockCreateCompressionGroup,
+    finalizeCompression: mockFinalizeCompression,
+  })),
+}));
+
 // @lobechat/model-runtime resolves to @cloud/business-model-runtime which has
 // cloud-specific dependencies that are unavailable in the test environment
 vi.mock('@lobechat/model-runtime', () => ({
@@ -44,9 +54,16 @@ describe('RuntimeExecutors', () => {
 
   beforeEach(() => {
     vi.clearAllMocks();
+    mockCreateCompressionGroup.mockResolvedValue({
+      messageGroupId: 'group-123',
+      messagesToSummarize: [],
+      success: true,
+    });
+    mockFinalizeCompression.mockResolvedValue({ success: true });
+
     mockMessageModel = {
       create: vi.fn().mockResolvedValue({ id: 'msg-123' }),
      query: vi.fn().mockResolvedValue([]),
       update: vi.fn().mockResolvedValue({}),
     };
 
@@ -113,6 +130,14 @@ describe('RuntimeExecutors', () => {
     total: 0,
   });
 
+  const createCompressContextInstruction = (messages: any[]) => ({
+    payload: {
+      currentTokenCount: 1000,
+      messages,
+    },
+    type: 'compress_context' as const,
+  });
+
   describe('call_llm executor', () => {
     const createMockState = (overrides?: Partial<AgentState>): AgentState => ({
       cost: createMockCost(),
@@ -261,6 +286,330 @@ describe('RuntimeExecutors', () => {
       );
     });
 
+    it('should execute compress_context and return compression_result', async () => {
+      const mockChat = vi.fn().mockImplementation(async (_payload, options) => {
+        await options?.callback?.onText?.('summary');
+        await options?.callback?.onCompletion?.({
+          usage: {
+            completionTokens: 5,
+            promptTokens: 10,
+            totalTokens: 15,
+          },
+        });
+        return new Response('done');
+      });
+      vi.mocked(initModelRuntimeFromDB).mockResolvedValueOnce({ chat: mockChat } as any);
+
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+        { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+      ]);
+      mockCreateCompressionGroup.mockResolvedValue({
+        messageGroupId: 'group-123',
+        messagesToSummarize: [{ content: 'history', id: 'msg-history', role: 'user' }],
+        success: true,
+      });
+      mockFinalizeCompression.mockResolvedValue({
+        messages: [
+          { content: 'summary', id: 'group-123', role: 'compressedGroup' },
+          { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+        ],
+        success: true,
+      });
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'x '.repeat(70000), role: 'user' }],
+      });
+
+      const instruction = createCompressContextInstruction([
+        { content: 'x '.repeat(70000), role: 'user' },
+      ]);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockCreateCompressionGroup).toHaveBeenCalledTimes(1);
+      expect(mockFinalizeCompression).toHaveBeenCalledTimes(1);
+      expect(mockChat).toHaveBeenCalledTimes(1);
+      expect(result.nextContext?.phase).toBe('compression_result');
+      expect((result.nextContext?.payload as any).compressedMessages[0]).toEqual({
+        content: 'summary',
+        id: 'group-123',
+        role: 'compressedGroup',
+      });
+      expect((result.nextContext?.payload as any).parentMessageId).toBe('assistant-existing');
+      expect(result.events).toContainEqual({
+        groupId: 'group-123',
+        parentMessageId: 'assistant-existing',
+        type: 'compression_complete',
+      });
+      expect(result.newState.usage.llm.tokens.total).toBe(15);
+    });
+
+    it('should skip compress_context when topic metadata is missing', async () => {
+      const executors = createRuntimeExecutors({
+        ...ctx,
+      });
+      const state = createMockState({
+        messages: [{ content: 'history', role: 'user' }],
+        metadata: {
+          agentId: 'agent-123',
+        },
+      });
+
+      const instruction = createCompressContextInstruction([{ content: 'history', role: 'user' }]);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockCreateCompressionGroup).not.toHaveBeenCalled();
+      expect((result.nextContext?.payload as any).skipped).toBe(true);
+    });
+
+    it('should skip compress_context when userId is missing', async () => {
+      const executors = createRuntimeExecutors({
+        ...ctx,
+        userId: undefined,
+      });
+      const state = createMockState({
+        messages: [{ content: 'history', role: 'user' }],
+      });
+
+      const instruction = createCompressContextInstruction([{ content: 'history', role: 'user' }]);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockCreateCompressionGroup).not.toHaveBeenCalled();
+      expect((result.nextContext?.payload as any).skipped).toBe(true);
+    });
+
+    it('should skip compress_context when there are no compressible messages after preserving the trailing user message', async () => {
+      mockMessageModel.query.mockResolvedValue([]);
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'continue with this exact instruction', role: 'user' }],
+      });
+
+      const instruction = createCompressContextInstruction(state.messages);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockCreateCompressionGroup).not.toHaveBeenCalled();
+      expect(result.nextContext?.payload as any).toMatchObject({
+        compressedMessages: state.messages,
+        groupId: '',
+        parentMessageId: undefined,
+        skipped: true,
+      });
+    });
+
+    it('should skip compress_context when compression model config is missing', async () => {
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+        { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+      ]);
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'history', role: 'user' }],
+        modelRuntimeConfig: undefined,
+      });
+
+      const instruction = createCompressContextInstruction([{ content: 'history', role: 'user' }]);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockCreateCompressionGroup).toHaveBeenCalledTimes(1);
+      expect(mockFinalizeCompression).not.toHaveBeenCalled();
+      expect(result.nextContext?.payload as any).toMatchObject({
+        compressedMessages: [{ content: 'history', role: 'user' }],
+        parentMessageId: 'assistant-existing',
+        skipped: true,
+      });
+    });
+
+    it('should continue when compress_context fails', async () => {
+      mockCreateCompressionGroup.mockRejectedValueOnce(new Error('compression failed'));
+
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+      ]);
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'history', role: 'user' }],
+      });
+
+      const instruction = createCompressContextInstruction([{ content: 'history', role: 'user' }]);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(result.nextContext?.phase).toBe('compression_result');
+      expect((result.nextContext?.payload as any).skipped).toBe(true);
+      expect(mockFinalizeCompression).not.toHaveBeenCalled();
+      expect(result.events).toHaveLength(1);
+      expect(result.events[0]).toMatchObject({ type: 'compression_error' });
+    });
+
+    it('should preserve the trailing user message outside compression', async () => {
+      const mockChat = vi.fn().mockImplementation(async (_payload, options) => {
+        await options?.callback?.onText?.('summary');
+        return new Response('done');
+      });
+      vi.mocked(initModelRuntimeFromDB).mockResolvedValueOnce({ chat: mockChat } as any);
+
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+        { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+      ]);
+      mockCreateCompressionGroup.mockResolvedValue({
+        messageGroupId: 'group-123',
+        messagesToSummarize: [{ content: 'history', id: 'msg-history', role: 'user' }],
+        success: true,
+      });
+      mockFinalizeCompression.mockResolvedValue({
+        messages: [{ content: 'summary', id: 'group-123', role: 'compressedGroup' }],
+        success: true,
+      });
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [
+          { content: 'history', id: 'msg-history', role: 'user' },
+          { content: 'continue with this exact instruction', role: 'user' },
+        ],
+      });
+
+      const instruction = createCompressContextInstruction(state.messages);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockCreateCompressionGroup).toHaveBeenCalledWith(
+        'topic-123',
+        ['msg-history', 'assistant-existing'],
+        expect.any(Object),
+      );
+      expect((result.nextContext?.payload as any).compressedMessages).toEqual([
+        { content: 'summary', id: 'group-123', role: 'compressedGroup' },
+        { content: 'continue with this exact instruction', role: 'user' },
+      ]);
+    });
+
+    it('should fallback to messagesToSummarize when finalizeCompression does not return messages', async () => {
+      const mockChat = vi.fn().mockImplementation(async (_payload, options) => {
+        await options?.callback?.onText?.('summary');
+        return new Response('done');
+      });
+      vi.mocked(initModelRuntimeFromDB).mockResolvedValueOnce({ chat: mockChat } as any);
+
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+        { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+      ]);
+      mockCreateCompressionGroup.mockResolvedValue({
+        messageGroupId: 'group-123',
+        messagesToSummarize: [{ content: 'history', id: 'msg-history', role: 'user' }],
+        success: true,
+      });
+      mockFinalizeCompression.mockResolvedValue({
+        messages: undefined,
+        success: true,
+      });
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'history', role: 'user' }],
+      });
+
+      const instruction = createCompressContextInstruction(state.messages);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect((result.nextContext?.payload as any).compressedMessages).toEqual([
+        { content: 'history', id: 'msg-history', role: 'user' },
+      ]);
+    });
+
+    it('should not duplicate the preserved trailing user message when it is already present in finalized messages', async () => {
+      const preservedMessage = {
+        content: 'continue with this exact instruction',
+        id: 'msg-follow-up',
+        role: 'user',
+      };
+
+      const mockChat = vi.fn().mockImplementation(async (_payload, options) => {
+        await options?.callback?.onText?.('summary');
+        return new Response('done');
+      });
+      vi.mocked(initModelRuntimeFromDB).mockResolvedValueOnce({ chat: mockChat } as any);
+
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+        { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+        preservedMessage,
+      ]);
+      mockCreateCompressionGroup.mockResolvedValue({
+        messageGroupId: 'group-123',
+        messagesToSummarize: [{ content: 'history', id: 'msg-history', role: 'user' }],
+        success: true,
+      });
+      mockFinalizeCompression.mockResolvedValue({
+        messages: [
+          { content: 'summary', id: 'group-123', role: 'compressedGroup' },
+          preservedMessage,
+        ],
+        success: true,
+      });
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'history', id: 'msg-history', role: 'user' }, preservedMessage],
+      });
+
+      const instruction = createCompressContextInstruction(state.messages);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect((result.nextContext?.payload as any).compressedMessages).toEqual([
+        { content: 'summary', id: 'group-123', role: 'compressedGroup' },
+        preservedMessage,
+      ]);
+    });
+
+    it('should continue with skipped compression when the compression model reports a summary error', async () => {
+      const mockChat = vi.fn().mockImplementation(async (_payload, options) => {
+        await options?.callback?.onError?.({ message: 'summary failed' });
+        return new Response('done');
+      });
+      vi.mocked(initModelRuntimeFromDB).mockResolvedValueOnce({ chat: mockChat } as any);
+
+      mockMessageModel.query.mockResolvedValue([
+        { content: 'history', id: 'msg-history', role: 'user' },
+        { content: 'loading', id: 'assistant-existing', role: 'assistant' },
+      ]);
+      mockCreateCompressionGroup.mockResolvedValue({
+        messageGroupId: 'group-123',
+        messagesToSummarize: [{ content: 'history', id: 'msg-history', role: 'user' }],
+        success: true,
+      });
+
+      const executors = createRuntimeExecutors(ctx);
+      const state = createMockState({
+        messages: [{ content: 'history', role: 'user' }],
+      });
+
+      const instruction = createCompressContextInstruction(state.messages);
+
+      const result = await executors.compress_context!(instruction, state);
+
+      expect(mockFinalizeCompression).not.toHaveBeenCalled();
+      expect((result.nextContext?.payload as any).skipped).toBe(true);
+      expect(result.events).toContainEqual(
+        expect.objectContaining({
+          type: 'compression_error',
+        }),
+      );
+    });
+
     describe('assistantMessageId reuse', () => {
       it('should reuse existing assistant message when assistantMessageId is provided', async () => {
         const executors = createRuntimeExecutors(ctx);
 