diff --git a/apps/desktop/package.json b/apps/desktop/package.json index fddc3ebf24..10a97d842e 100644 --- a/apps/desktop/package.json +++ b/apps/desktop/package.json @@ -59,6 +59,7 @@ "electron-store": "^8.2.0", "electron-vite": "^3.0.0", "execa": "^9.5.2", + "fast-glob": "^3.3.3", "fix-path": "^5.0.0", "http-proxy-agent": "^7.0.2", "https-proxy-agent": "^7.0.6", diff --git a/apps/desktop/src/main/controllers/LocalFileCtr.ts b/apps/desktop/src/main/controllers/LocalFileCtr.ts index b59604547c..4578443dbd 100644 --- a/apps/desktop/src/main/controllers/LocalFileCtr.ts +++ b/apps/desktop/src/main/controllers/LocalFileCtr.ts @@ -1,4 +1,10 @@ import { + EditLocalFileParams, + EditLocalFileResult, + GlobFilesParams, + GlobFilesResult, + GrepContentParams, + GrepContentResult, ListLocalFileParams, LocalMoveFilesResultItem, LocalReadFileParams, @@ -13,10 +19,10 @@ import { } from '@lobechat/electron-client-ipc'; import { SYSTEM_FILES_TO_IGNORE, loadFile } from '@lobechat/file-loaders'; import { shell } from 'electron'; -import * as fs from 'node:fs'; -import { rename as renamePromise } from 'node:fs/promises'; +import fg from 'fast-glob'; +import { Stats, constants } from 'node:fs'; +import { access, mkdir, readFile, readdir, rename, stat, writeFile } from 'node:fs/promises'; import * as path from 'node:path'; -import { promisify } from 'node:util'; import FileSearchService from '@/services/fileSearchSrv'; import { FileResult, SearchOptions } from '@/types/fileSearch'; @@ -25,40 +31,15 @@ import { createLogger } from '@/utils/logger'; import { ControllerModule, ipcClientEvent } from './index'; -// 创建日志记录器 +// Create logger const logger = createLogger('controllers:LocalFileCtr'); -const statPromise = promisify(fs.stat); -const readdirPromise = promisify(fs.readdir); -const renamePromiseFs = promisify(fs.rename); -const accessPromise = promisify(fs.access); -const writeFilePromise = promisify(fs.writeFile); - export default class LocalFileCtr extends ControllerModule { private get searchService() { return this.app.getService(FileSearchService); } - /** - * Handle IPC event for local file search - */ - @ipcClientEvent('searchLocalFiles') - async handleLocalFilesSearch(params: LocalSearchFilesParams): Promise { - logger.debug('Received file search request:', { keywords: params.keywords }); - - const options: Omit = { - limit: 30, - }; - - try { - const results = await this.searchService.search(params.keywords, options); - logger.debug('File search completed', { count: results.length }); - return results; - } catch (error) { - logger.error('File search failed:', error); - return []; - } - } + // ==================== File Operation ==================== @ipcClientEvent('openLocalFile') async handleOpenLocalFile({ path: filePath }: OpenLocalFileParams): Promise<{ @@ -102,7 +83,7 @@ export default class LocalFileCtr extends ControllerModule { const results: LocalReadFileResult[] = []; for (const filePath of paths) { - // 初始化结果对象 + // Initialize result object logger.debug('Reading single file:', { filePath }); const result = await this.readFile({ path: filePath }); results.push(result); @@ -158,7 +139,7 @@ export default class LocalFileCtr extends ControllerModule { }; try { - const stats = await statPromise(filePath); + const stats = await stat(filePath); if (stats.isDirectory()) { logger.warn('Attempted to read directory content:', { filePath }); result.content = 'This is a directory and cannot be read as plain text.'; @@ -197,7 +178,7 @@ export default class LocalFileCtr extends ControllerModule { const results: FileResult[] = []; try { - const entries = await readdirPromise(dirPath); + const entries = await readdir(dirPath); logger.debug('Directory entries retrieved successfully:', { dirPath, entriesCount: entries.length, @@ -212,7 +193,7 @@ export default class LocalFileCtr extends ControllerModule { const fullPath = path.join(dirPath, entry); try { - const stats = await statPromise(fullPath); + const stats = await stat(fullPath); const isDirectory = stats.isDirectory(); results.push({ createdTime: stats.birthtime, @@ -260,7 +241,7 @@ export default class LocalFileCtr extends ControllerModule { return []; } - // 逐个处理移动请求 + // Process each move request for (const item of items) { const { oldPath: sourcePath, newPath } = item; const logPrefix = `[Moving file ${sourcePath} -> ${newPath}]`; @@ -272,7 +253,7 @@ export default class LocalFileCtr extends ControllerModule { success: false, }; - // 基本验证 + // Basic validation if (!sourcePath || !newPath) { logger.error(`${logPrefix} Parameter validation failed: source or target path is empty`); resultItem.error = 'Both oldPath and newPath are required for each item.'; @@ -281,9 +262,9 @@ export default class LocalFileCtr extends ControllerModule { } try { - // 检查源是否存在 + // Check if source exists try { - await accessPromise(sourcePath, fs.constants.F_OK); + await access(sourcePath, constants.F_OK); logger.debug(`${logPrefix} Source file exists`); } catch (accessError: any) { if (accessError.code === 'ENOENT') { @@ -297,28 +278,28 @@ export default class LocalFileCtr extends ControllerModule { } } - // 检查目标路径是否与源路径相同 + // Check if target path is the same as source path if (path.normalize(sourcePath) === path.normalize(newPath)) { logger.info(`${logPrefix} Source and target paths are identical, skipping move`); resultItem.success = true; - resultItem.newPath = newPath; // 即使未移动,也报告目标路径 + resultItem.newPath = newPath; // Report target path even if not moved results.push(resultItem); continue; } - // LBYL: 确保目标目录存在 + // LBYL: Ensure target directory exists const targetDir = path.dirname(newPath); makeSureDirExist(targetDir); logger.debug(`${logPrefix} Ensured target directory exists: ${targetDir}`); - // 执行移动 (rename) - await renamePromiseFs(sourcePath, newPath); + // Execute move (rename) + await rename(sourcePath, newPath); resultItem.success = true; resultItem.newPath = newPath; logger.info(`${logPrefix} Move successful`); } catch (error) { logger.error(`${logPrefix} Move failed:`, error); - // 使用与 handleMoveFile 类似的错误处理逻辑 + // Use similar error handling logic as handleMoveFile let errorMessage = (error as Error).message; if ((error as any).code === 'ENOENT') errorMessage = `Source path not found: ${sourcePath}.`; @@ -334,7 +315,7 @@ export default class LocalFileCtr extends ControllerModule { errorMessage = `The target directory ${newPath} is not empty (relevant on some systems if target exists and is a directory).`; else if ((error as any).code === 'EEXIST') errorMessage = `An item already exists at the target path: ${newPath}.`; - // 保留来自访问检查或目录检查的更具体错误 + // Keep more specific errors from access or directory checks else if ( !errorMessage.startsWith('Source path not found') && !errorMessage.startsWith('Permission denied accessing source path') && @@ -411,9 +392,9 @@ export default class LocalFileCtr extends ControllerModule { }; } - // Perform the rename operation using fs.promises.rename directly + // Perform the rename operation using rename directly try { - await renamePromise(currentPath, newPath); + await rename(currentPath, newPath); logger.info(`${logPrefix} Rename successful: ${currentPath} -> ${newPath}`); // Optionally return the newPath if frontend needs it // return { success: true, newPath: newPath }; @@ -444,7 +425,7 @@ export default class LocalFileCtr extends ControllerModule { const logPrefix = `[Writing file ${filePath}]`; logger.debug(`${logPrefix} Starting to write file`, { contentLength: content?.length }); - // 验证参数 + // Validate parameters if (!filePath) { logger.error(`${logPrefix} Parameter validation failed: path is empty`); return { error: 'Path cannot be empty', success: false }; @@ -456,14 +437,14 @@ export default class LocalFileCtr extends ControllerModule { } try { - // 确保目标目录存在 + // Ensure target directory exists (use async to avoid blocking main thread) const dirname = path.dirname(filePath); logger.debug(`${logPrefix} Creating directory: ${dirname}`); - fs.mkdirSync(dirname, { recursive: true }); + await mkdir(dirname, { recursive: true }); - // 写入文件内容 + // Write file content logger.debug(`${logPrefix} Starting to write content to file`); - await writeFilePromise(filePath, content, 'utf8'); + await writeFile(filePath, content, 'utf8'); logger.info(`${logPrefix} File written successfully`, { path: filePath, size: content.length, @@ -478,4 +459,250 @@ export default class LocalFileCtr extends ControllerModule { }; } } + + // ==================== Search & Find ==================== + + /** + * Handle IPC event for local file search + */ + @ipcClientEvent('searchLocalFiles') + async handleLocalFilesSearch(params: LocalSearchFilesParams): Promise { + logger.debug('Received file search request:', { keywords: params.keywords }); + + const options: Omit = { + limit: 30, + }; + + try { + const results = await this.searchService.search(params.keywords, options); + logger.debug('File search completed', { count: results.length }); + return results; + } catch (error) { + logger.error('File search failed:', error); + return []; + } + } + + @ipcClientEvent('grepContent') + async handleGrepContent(params: GrepContentParams): Promise { + const { + pattern, + path: searchPath = process.cwd(), + output_mode = 'files_with_matches', + } = params; + const logPrefix = `[grepContent: ${pattern}]`; + logger.debug(`${logPrefix} Starting content search`, { output_mode, searchPath }); + + try { + const regex = new RegExp( + pattern, + `g${params['-i'] ? 'i' : ''}${params.multiline ? 's' : ''}`, + ); + + // Determine files to search + let filesToSearch: string[] = []; + const stats = await stat(searchPath); + + if (stats.isFile()) { + filesToSearch = [searchPath]; + } else { + // Use glob pattern if provided, otherwise search all files + const globPattern = params.glob || '**/*'; + filesToSearch = await fg(globPattern, { + absolute: true, + cwd: searchPath, + dot: true, + ignore: ['**/node_modules/**', '**/.git/**'], + }); + + // Filter by type if provided + if (params.type) { + const ext = `.${params.type}`; + filesToSearch = filesToSearch.filter((file) => file.endsWith(ext)); + } + } + + logger.debug(`${logPrefix} Found ${filesToSearch.length} files to search`); + + const matches: string[] = []; + let totalMatches = 0; + + for (const filePath of filesToSearch) { + try { + const fileStats = await stat(filePath); + if (!fileStats.isFile()) continue; + + const content = await readFile(filePath, 'utf8'); + const lines = content.split('\n'); + + switch (output_mode) { + case 'files_with_matches': { + if (regex.test(content)) { + matches.push(filePath); + totalMatches++; + if (params.head_limit && matches.length >= params.head_limit) break; + } + break; + } + case 'content': { + const matchedLines: string[] = []; + for (let i = 0; i < lines.length; i++) { + if (regex.test(lines[i])) { + const contextBefore = params['-B'] || params['-C'] || 0; + const contextAfter = params['-A'] || params['-C'] || 0; + + const startLine = Math.max(0, i - contextBefore); + const endLine = Math.min(lines.length - 1, i + contextAfter); + + for (let j = startLine; j <= endLine; j++) { + const lineNum = params['-n'] ? `${j + 1}:` : ''; + matchedLines.push(`${filePath}:${lineNum}${lines[j]}`); + } + totalMatches++; + } + } + matches.push(...matchedLines); + if (params.head_limit && matches.length >= params.head_limit) break; + break; + } + case 'count': { + const fileMatches = (content.match(regex) || []).length; + if (fileMatches > 0) { + matches.push(`${filePath}:${fileMatches}`); + totalMatches += fileMatches; + } + break; + } + } + } catch (error) { + logger.debug(`${logPrefix} Skipping file ${filePath}:`, error); + } + } + + logger.info(`${logPrefix} Search completed`, { + matchCount: matches.length, + totalMatches, + }); + + return { + matches: params.head_limit ? matches.slice(0, params.head_limit) : matches, + success: true, + total_matches: totalMatches, + }; + } catch (error) { + logger.error(`${logPrefix} Grep failed:`, error); + return { + matches: [], + success: false, + total_matches: 0, + }; + } + } + + @ipcClientEvent('globLocalFiles') + async handleGlobFiles({ + path: searchPath = process.cwd(), + pattern, + }: GlobFilesParams): Promise { + const logPrefix = `[globFiles: ${pattern}]`; + logger.debug(`${logPrefix} Starting glob search`, { searchPath }); + + try { + const files = await fg(pattern, { + absolute: true, + cwd: searchPath, + dot: true, + onlyFiles: false, + stats: true, + }); + + // Sort by modification time (most recent first) + const sortedFiles = (files as unknown as Array<{ path: string; stats: Stats }>) + .sort((a, b) => b.stats.mtime.getTime() - a.stats.mtime.getTime()) + .map((f) => f.path); + + logger.info(`${logPrefix} Glob completed`, { fileCount: sortedFiles.length }); + + return { + files: sortedFiles, + success: true, + total_files: sortedFiles.length, + }; + } catch (error) { + logger.error(`${logPrefix} Glob failed:`, error); + return { + files: [], + success: false, + total_files: 0, + }; + } + } + + // ==================== File Editing ==================== + + @ipcClientEvent('editLocalFile') + async handleEditFile({ + file_path: filePath, + new_string, + old_string, + replace_all = false, + }: EditLocalFileParams): Promise { + const logPrefix = `[editFile: ${filePath}]`; + logger.debug(`${logPrefix} Starting file edit`, { replace_all }); + + try { + // Read file content + const content = await readFile(filePath, 'utf8'); + + // Check if old_string exists + if (!content.includes(old_string)) { + logger.error(`${logPrefix} Old string not found in file`); + return { + error: 'The specified old_string was not found in the file', + replacements: 0, + success: false, + }; + } + + // Perform replacement + let newContent: string; + let replacements: number; + + if (replace_all) { + const regex = new RegExp(old_string.replaceAll(/[$()*+.?[\\\]^{|}]/g, '\\$&'), 'g'); + const matches = content.match(regex); + replacements = matches ? matches.length : 0; + newContent = content.replaceAll(old_string, new_string); + } else { + // Replace only first occurrence + const index = content.indexOf(old_string); + if (index === -1) { + return { + error: 'Old string not found', + replacements: 0, + success: false, + }; + } + newContent = + content.slice(0, index) + new_string + content.slice(index + old_string.length); + replacements = 1; + } + + // Write back to file + await writeFile(filePath, newContent, 'utf8'); + + logger.info(`${logPrefix} File edited successfully`, { replacements }); + return { + replacements, + success: true, + }; + } catch (error) { + logger.error(`${logPrefix} Edit failed:`, error); + return { + error: (error as Error).message, + replacements: 0, + success: false, + }; + } + } } diff --git a/apps/desktop/src/main/controllers/__tests__/LocalFileCtr.test.ts b/apps/desktop/src/main/controllers/__tests__/LocalFileCtr.test.ts new file mode 100644 index 0000000000..0c9be6a337 --- /dev/null +++ b/apps/desktop/src/main/controllers/__tests__/LocalFileCtr.test.ts @@ -0,0 +1,392 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { App } from '@/core/App'; + +import LocalFileCtr from '../LocalFileCtr'; + +// Mock logger +vi.mock('@/utils/logger', () => ({ + createLogger: () => ({ + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }), +})); + +// Mock file-loaders +vi.mock('@lobechat/file-loaders', () => ({ + SYSTEM_FILES_TO_IGNORE: ['.DS_Store', 'Thumbs.db'], + loadFile: vi.fn(), +})); + +// Mock electron +vi.mock('electron', () => ({ + shell: { + openPath: vi.fn(), + }, +})); + +// Mock fast-glob +vi.mock('fast-glob', () => ({ + default: vi.fn(), +})); + +// Mock node:fs/promises and node:fs +vi.mock('node:fs/promises', () => ({ + stat: vi.fn(), + readdir: vi.fn(), + rename: vi.fn(), + access: vi.fn(), + writeFile: vi.fn(), + readFile: vi.fn(), + mkdir: vi.fn(), +})); + +vi.mock('node:fs', () => ({ + Stats: class Stats {}, + constants: { + F_OK: 0, + }, + stat: vi.fn(), + readdir: vi.fn(), + rename: vi.fn(), + access: vi.fn(), + writeFile: vi.fn(), + readFile: vi.fn(), +})); + +// Mock FileSearchService +const mockSearchService = { + search: vi.fn(), +}; + +// Mock makeSureDirExist +vi.mock('@/utils/file-system', () => ({ + makeSureDirExist: vi.fn(), +})); + +const mockApp = { + getService: vi.fn(() => mockSearchService), +} as unknown as App; + +describe('LocalFileCtr', () => { + let localFileCtr: LocalFileCtr; + let mockShell: any; + let mockFg: any; + let mockLoadFile: any; + let mockFsPromises: any; + + beforeEach(async () => { + vi.clearAllMocks(); + + // Import mocks + mockShell = (await import('electron')).shell; + mockFg = (await import('fast-glob')).default; + mockLoadFile = (await import('@lobechat/file-loaders')).loadFile; + mockFsPromises = await import('node:fs/promises'); + + localFileCtr = new LocalFileCtr(mockApp); + }); + + describe('handleOpenLocalFile', () => { + it('should open file successfully', async () => { + vi.mocked(mockShell.openPath).mockResolvedValue(''); + + const result = await localFileCtr.handleOpenLocalFile({ path: '/test/file.txt' }); + + expect(result).toEqual({ success: true }); + expect(mockShell.openPath).toHaveBeenCalledWith('/test/file.txt'); + }); + + it('should return error when opening file fails', async () => { + const error = new Error('Failed to open'); + vi.mocked(mockShell.openPath).mockRejectedValue(error); + + const result = await localFileCtr.handleOpenLocalFile({ path: '/test/file.txt' }); + + expect(result).toEqual({ success: false, error: 'Failed to open' }); + }); + }); + + describe('handleOpenLocalFolder', () => { + it('should open directory when isDirectory is true', async () => { + vi.mocked(mockShell.openPath).mockResolvedValue(''); + + const result = await localFileCtr.handleOpenLocalFolder({ + path: '/test/folder', + isDirectory: true, + }); + + expect(result).toEqual({ success: true }); + expect(mockShell.openPath).toHaveBeenCalledWith('/test/folder'); + }); + + it('should open parent directory when isDirectory is false', async () => { + vi.mocked(mockShell.openPath).mockResolvedValue(''); + + const result = await localFileCtr.handleOpenLocalFolder({ + path: '/test/folder/file.txt', + isDirectory: false, + }); + + expect(result).toEqual({ success: true }); + expect(mockShell.openPath).toHaveBeenCalledWith('/test/folder'); + }); + + it('should return error when opening folder fails', async () => { + const error = new Error('Failed to open folder'); + vi.mocked(mockShell.openPath).mockRejectedValue(error); + + const result = await localFileCtr.handleOpenLocalFolder({ + path: '/test/folder', + isDirectory: true, + }); + + expect(result).toEqual({ success: false, error: 'Failed to open folder' }); + }); + }); + + describe('readFile', () => { + it('should read file successfully with default location', async () => { + const mockFileContent = 'line1\nline2\nline3\nline4\nline5'; + vi.mocked(mockLoadFile).mockResolvedValue({ + content: mockFileContent, + filename: 'test.txt', + fileType: 'txt', + createdTime: new Date('2024-01-01'), + modifiedTime: new Date('2024-01-02'), + }); + + const result = await localFileCtr.readFile({ path: '/test/file.txt' }); + + expect(result.filename).toBe('test.txt'); + expect(result.fileType).toBe('txt'); + expect(result.totalLineCount).toBe(5); + expect(result.content).toBe(mockFileContent); + }); + + it('should read file with custom location range', async () => { + const mockFileContent = 'line1\nline2\nline3\nline4\nline5'; + vi.mocked(mockLoadFile).mockResolvedValue({ + content: mockFileContent, + filename: 'test.txt', + fileType: 'txt', + createdTime: new Date('2024-01-01'), + modifiedTime: new Date('2024-01-02'), + }); + + const result = await localFileCtr.readFile({ path: '/test/file.txt', loc: [1, 3] }); + + expect(result.content).toBe('line2\nline3'); + expect(result.lineCount).toBe(2); + expect(result.totalLineCount).toBe(5); + }); + + it('should handle file read error', async () => { + vi.mocked(mockLoadFile).mockRejectedValue(new Error('File not found')); + + const result = await localFileCtr.readFile({ path: '/test/missing.txt' }); + + expect(result.content).toContain('Error accessing or processing file'); + expect(result.lineCount).toBe(0); + expect(result.charCount).toBe(0); + }); + }); + + describe('readFiles', () => { + it('should read multiple files successfully', async () => { + vi.mocked(mockLoadFile).mockResolvedValue({ + content: 'file content', + filename: 'test.txt', + fileType: 'txt', + createdTime: new Date('2024-01-01'), + modifiedTime: new Date('2024-01-02'), + }); + + const result = await localFileCtr.readFiles({ + paths: ['/test/file1.txt', '/test/file2.txt'], + }); + + expect(result).toHaveLength(2); + expect(mockLoadFile).toHaveBeenCalledTimes(2); + }); + }); + + describe('handleWriteFile', () => { + it('should write file successfully', async () => { + vi.mocked(mockFsPromises.mkdir).mockResolvedValue(undefined); + vi.mocked(mockFsPromises.writeFile).mockResolvedValue(undefined); + + const result = await localFileCtr.handleWriteFile({ + path: '/test/file.txt', + content: 'test content', + }); + + expect(result).toEqual({ success: true }); + }); + + it('should return error when path is empty', async () => { + const result = await localFileCtr.handleWriteFile({ + path: '', + content: 'test content', + }); + + expect(result).toEqual({ success: false, error: 'Path cannot be empty' }); + }); + + it('should return error when content is undefined', async () => { + const result = await localFileCtr.handleWriteFile({ + path: '/test/file.txt', + content: undefined as any, + }); + + expect(result).toEqual({ success: false, error: 'Content cannot be empty' }); + }); + + it('should handle write error', async () => { + vi.mocked(mockFsPromises.mkdir).mockResolvedValue(undefined); + vi.mocked(mockFsPromises.writeFile).mockRejectedValue(new Error('Write failed')); + + const result = await localFileCtr.handleWriteFile({ + path: '/test/file.txt', + content: 'test content', + }); + + expect(result).toEqual({ success: false, error: 'Failed to write file: Write failed' }); + }); + }); + + describe('handleRenameFile', () => { + it('should rename file successfully', async () => { + vi.mocked(mockFsPromises.rename).mockResolvedValue(undefined); + + const result = await localFileCtr.handleRenameFile({ + path: '/test/old.txt', + newName: 'new.txt', + }); + + expect(result).toEqual({ success: true, newPath: '/test/new.txt' }); + expect(mockFsPromises.rename).toHaveBeenCalledWith('/test/old.txt', '/test/new.txt'); + }); + + it('should skip rename when paths are identical', async () => { + const result = await localFileCtr.handleRenameFile({ + path: '/test/file.txt', + newName: 'file.txt', + }); + + expect(result).toEqual({ success: true, newPath: '/test/file.txt' }); + expect(mockFsPromises.rename).not.toHaveBeenCalled(); + }); + + it('should reject invalid new name with path separators', async () => { + const result = await localFileCtr.handleRenameFile({ + path: '/test/old.txt', + newName: '../new.txt', + }); + + expect(result.success).toBe(false); + expect(result.error).toContain('Invalid new name'); + }); + + it('should reject invalid new name with special characters', async () => { + const result = await localFileCtr.handleRenameFile({ + path: '/test/old.txt', + newName: 'new:file.txt', + }); + + expect(result.success).toBe(false); + expect(result.error).toContain('Invalid new name'); + }); + + it('should handle file not found error', async () => { + const error: any = new Error('File not found'); + error.code = 'ENOENT'; + vi.mocked(mockFsPromises.rename).mockRejectedValue(error); + + const result = await localFileCtr.handleRenameFile({ + path: '/test/old.txt', + newName: 'new.txt', + }); + + expect(result.success).toBe(false); + expect(result.error).toContain('File or directory not found'); + }); + + it('should handle file already exists error', async () => { + const error: any = new Error('File exists'); + error.code = 'EEXIST'; + vi.mocked(mockFsPromises.rename).mockRejectedValue(error); + + const result = await localFileCtr.handleRenameFile({ + path: '/test/old.txt', + newName: 'new.txt', + }); + + expect(result.success).toBe(false); + expect(result.error).toContain('already exists'); + }); + }); + + describe('handleLocalFilesSearch', () => { + it('should search files successfully', async () => { + const mockResults = [ + { + name: 'test.txt', + path: '/test/test.txt', + isDirectory: false, + size: 100, + type: 'txt', + }, + ]; + mockSearchService.search.mockResolvedValue(mockResults); + + const result = await localFileCtr.handleLocalFilesSearch({ keywords: 'test' }); + + expect(result).toEqual(mockResults); + expect(mockSearchService.search).toHaveBeenCalledWith('test', { limit: 30 }); + }); + + it('should return empty array on search error', async () => { + mockSearchService.search.mockRejectedValue(new Error('Search failed')); + + const result = await localFileCtr.handleLocalFilesSearch({ keywords: 'test' }); + + expect(result).toEqual([]); + }); + }); + + describe('handleGlobFiles', () => { + it('should glob files successfully', async () => { + const mockFiles = [ + { path: '/test/file1.txt', stats: { mtime: new Date('2024-01-02') } }, + { path: '/test/file2.txt', stats: { mtime: new Date('2024-01-01') } }, + ]; + vi.mocked(mockFg).mockResolvedValue(mockFiles); + + const result = await localFileCtr.handleGlobFiles({ + pattern: '*.txt', + path: '/test', + }); + + expect(result.success).toBe(true); + expect(result.files).toEqual(['/test/file1.txt', '/test/file2.txt']); + expect(result.total_files).toBe(2); + }); + + it('should handle glob error', async () => { + vi.mocked(mockFg).mockRejectedValue(new Error('Glob failed')); + + const result = await localFileCtr.handleGlobFiles({ + pattern: '*.txt', + }); + + expect(result).toEqual({ + success: false, + files: [], + total_files: 0, + }); + }); + }); +}); diff --git a/packages/agent-runtime/src/core/InterventionChecker.ts b/packages/agent-runtime/src/core/InterventionChecker.ts new file mode 100644 index 0000000000..ed5c13a9be --- /dev/null +++ b/packages/agent-runtime/src/core/InterventionChecker.ts @@ -0,0 +1,173 @@ +import type { + ArgumentMatcher, + HumanInterventionPolicy, + HumanInterventionRule, + ShouldInterveneParams, +} from '@lobechat/types'; + +/** + * Intervention Checker + * Determines whether a tool call requires human intervention + */ +export class InterventionChecker { + /** + * Check if a tool call requires intervention + * + * @param params - Parameters object containing config, toolArgs, confirmedHistory, and toolKey + * @returns Policy to apply + */ + static shouldIntervene(params: ShouldInterveneParams): HumanInterventionPolicy { + const { config, toolArgs = {}, confirmedHistory = [], toolKey } = params; + + // No config means never intervene (auto-execute) + if (!config) return 'never'; + + // Simple policy string + if (typeof config === 'string') { + // For 'first' policy, check if already confirmed + if (config === 'first' && toolKey && confirmedHistory.includes(toolKey)) { + return 'never'; + } + return config; + } + + // Array of rules - find first matching rule + for (const rule of config) { + if (this.matchesRule(rule, toolArgs)) { + const policy = rule.policy; + + // For 'first' policy, check if already confirmed + if (policy === 'first' && toolKey && confirmedHistory.includes(toolKey)) { + return 'never'; + } + + return policy; + } + } + + // No rule matched - default to always for safety + return 'always'; + } + + /** + * Check if tool arguments match a rule + * + * @param rule - Rule to check + * @param toolArgs - Tool call arguments + * @returns true if matches + */ + private static matchesRule(rule: HumanInterventionRule, toolArgs: Record): boolean { + // No match criteria means it's a default rule + if (!rule.match) return true; + + // Check each parameter matcher + for (const [paramName, matcher] of Object.entries(rule.match)) { + const paramValue = toolArgs[paramName]; + + // Parameter not present in args + if (paramValue === undefined) return false; + + // Check if value matches + if (!this.matchesArgument(matcher, paramValue)) { + return false; + } + } + + return true; + } + + /** + * Check if a parameter value matches the matcher + * + * @param matcher - Argument matcher + * @param value - Parameter value + * @returns true if matches + */ + private static matchesArgument(matcher: ArgumentMatcher, value: any): boolean { + const strValue = String(value); + + // Simple string matcher + if (typeof matcher === 'string') { + return this.matchPattern(matcher, strValue); + } + + // Complex matcher with type + const { pattern, type } = matcher; + + switch (type) { + case 'exact': { + return strValue === pattern; + } + case 'prefix': { + return strValue.startsWith(pattern); + } + case 'wildcard': { + return this.matchPattern(pattern, strValue); + } + case 'regex': { + return new RegExp(pattern).test(strValue); + } + default: { + return false; + } + } + } + + /** + * Match wildcard pattern (supports * wildcard) + * + * @param pattern - Pattern with wildcards + * @param value - Value to match + * @returns true if matches + */ + private static matchPattern(pattern: string, value: string): boolean { + // Check for colon-based prefix matching (e.g., "git add:*") + if (pattern.includes(':')) { + const [prefix, suffix] = pattern.split(':'); + if (suffix === '*') { + return value.startsWith(prefix + ':') || value === prefix; + } + } + + // Convert wildcard pattern to regex + const regexPattern = pattern + .replaceAll(/[$()+.?[\\\]^{|}]/g, '\\$&') // Escape special chars + .replaceAll('*', '.*'); // Replace * with .* + + return new RegExp(`^${regexPattern}$`).test(value); + } + + /** + * Generate tool key from identifier and API name + * + * @param identifier - Tool identifier + * @param apiName - API name + * @param argsHash - Optional hash of arguments + * @returns Tool key in format "identifier/apiName" or "identifier/apiName#hash" + */ + static generateToolKey(identifier: string, apiName: string, argsHash?: string): string { + const baseKey = `${identifier}/${apiName}`; + return argsHash ? `${baseKey}#${argsHash}` : baseKey; + } + + /** + * Generate simple hash of arguments for 'once' policy + * + * @param args - Tool call arguments + * @returns Hash string + */ + static hashArguments(args: Record): string { + const sortedKeys = Object.keys(args).sort(); + const str = sortedKeys.map((key) => `${key}=${JSON.stringify(args[key])}`).join('&'); + + // Simple hash function + let hash = 0; + for (let i = 0; i < str.length; i++) { + const char = str.charCodeAt(i); + hash = (hash << 5) - hash + char; + hash = hash & hash; // Convert to 32-bit integer + } + + return Math.abs(hash).toString(36); + } +} diff --git a/packages/agent-runtime/src/core/UsageCounter.ts b/packages/agent-runtime/src/core/UsageCounter.ts new file mode 100644 index 0000000000..899c9020ce --- /dev/null +++ b/packages/agent-runtime/src/core/UsageCounter.ts @@ -0,0 +1,248 @@ +import { ModelUsage } from '@lobechat/types'; + +import { Cost, Usage } from '../types/usage'; + +/** + * UsageCounter - Pure accumulator for usage and cost tracking + * Focuses only on usage/cost calculations without managing state + */ +/* eslint-disable unicorn/no-static-only-class */ +export class UsageCounter { + /** + * Create default usage statistics + */ + private static createDefaultUsage(): Usage { + return { + humanInteraction: { + approvalRequests: 0, + promptRequests: 0, + selectRequests: 0, + totalWaitingTimeMs: 0, + }, + llm: { + apiCalls: 0, + processingTimeMs: 0, + tokens: { + input: 0, + output: 0, + total: 0, + }, + }, + tools: { + byTool: [], + totalCalls: 0, + totalTimeMs: 0, + }, + }; + } + + /** + * Create default cost statistics + */ + private static createDefaultCost(): Cost { + return { + calculatedAt: new Date().toISOString(), + currency: 'USD', + llm: { + byModel: [], + currency: 'USD', + total: 0, + }, + tools: { + byTool: [], + currency: 'USD', + total: 0, + }, + total: 0, + }; + } + + /** + * Merge two ModelUsage objects by accumulating token counts + * @param previous - Previous usage statistics + * @param current - Current usage statistics to add + * @returns Merged usage statistics + */ + private static mergeModelUsage( + previous: ModelUsage | undefined, + current: ModelUsage, + ): ModelUsage { + if (!previous) return current; + + const merged: ModelUsage = { ...current }; + + // Accumulate all numeric token fields + const numericFields: (keyof ModelUsage)[] = [ + 'inputCachedTokens', + 'inputCacheMissTokens', + 'inputWriteCacheTokens', + 'inputTextTokens', + 'inputImageTokens', + 'inputAudioTokens', + 'inputCitationTokens', + 'outputTextTokens', + 'outputImageTokens', + 'outputAudioTokens', + 'outputReasoningTokens', + 'acceptedPredictionTokens', + 'rejectedPredictionTokens', + 'totalInputTokens', + 'totalOutputTokens', + 'totalTokens', + ]; + + for (const field of numericFields) { + const prevValue = previous[field] as number | undefined; + const currValue = current[field] as number | undefined; + + if (prevValue !== undefined || currValue !== undefined) { + merged[field] = (prevValue || 0) + (currValue || 0); + } + } + + // Accumulate cost + if (previous.cost !== undefined || current.cost !== undefined) { + merged.cost = (previous.cost || 0) + (current.cost || 0); + } + + return merged; + } + + /** + * Accumulate LLM usage and cost for a specific model + * @param params - Accumulation parameters + * @param params.usage - Current usage statistics (optional, will be created if not provided) + * @param params.cost - Current cost statistics (optional, will be created if not provided) + * @param params.provider - Provider name (e.g., "openai") + * @param params.model - Model name (e.g., "gpt-4") + * @param params.modelUsage - ModelUsage from model-runtime + * @returns Updated usage and cost + */ + static accumulateLLM(params: { + cost?: Cost; + model: string; + modelUsage: ModelUsage; + provider: string; + usage?: Usage; + }): { cost?: Cost; usage: Usage } { + const { usage, cost, provider, model, modelUsage } = params; + + // Ensure usage exists + const newUsage = usage ? structuredClone(usage) : this.createDefaultUsage(); + + // Accumulate token counts to usage.llm + newUsage.llm.tokens.input += modelUsage.totalInputTokens ?? 0; + newUsage.llm.tokens.output += modelUsage.totalOutputTokens ?? 0; + newUsage.llm.tokens.total += modelUsage.totalTokens ?? 0; + newUsage.llm.apiCalls += 1; + + // Ensure cost exists if modelUsage has cost + let newCost = cost + ? structuredClone(cost) + : modelUsage.cost + ? this.createDefaultCost() + : undefined; + + if (modelUsage.cost && newCost) { + const modelId = `${provider}/${model}`; + + // Find or create byModel entry + let modelEntry = newCost.llm.byModel.find((entry) => entry.id === modelId); + + if (!modelEntry) { + modelEntry = { + id: modelId, + model, + provider, + totalCost: 0, + usage: {}, + }; + newCost.llm.byModel.push(modelEntry); + } + + // Merge usage breakdown + modelEntry.usage = UsageCounter.mergeModelUsage(modelEntry.usage, modelUsage); + + // Accumulate costs + modelEntry.totalCost += modelUsage.cost; + newCost.llm.total += modelUsage.cost; + newCost.total += modelUsage.cost; + newCost.calculatedAt = new Date().toISOString(); + } + + return { cost: newCost, usage: newUsage }; + } + + /** + * Accumulate tool usage and cost + * @param params - Accumulation parameters + * @param params.usage - Current usage statistics (optional, will be created if not provided) + * @param params.cost - Current cost statistics (optional, will be created if not provided) + * @param params.toolName - Tool identifier + * @param params.executionTime - Execution time in milliseconds + * @param params.success - Whether the execution was successful + * @param params.toolCost - Optional cost for this tool call + * @returns Updated usage and cost + */ + static accumulateTool(params: { + cost?: Cost; + executionTime: number; + success: boolean; + toolCost?: number; + toolName: string; + usage?: Usage; + }): { cost?: Cost; usage: Usage } { + const { usage, cost, toolName, executionTime, success, toolCost } = params; + + // Ensure usage exists + const newUsage = usage ? structuredClone(usage) : this.createDefaultUsage(); + + // Find or create byTool entry + let toolEntry = newUsage.tools.byTool.find((entry) => entry.name === toolName); + + if (!toolEntry) { + toolEntry = { + calls: 0, + errors: 0, + name: toolName, + totalTimeMs: 0, + }; + newUsage.tools.byTool.push(toolEntry); + } + + // Accumulate tool usage + toolEntry.calls += 1; + toolEntry.totalTimeMs += executionTime; + if (!success) { + toolEntry.errors += 1; + } + + newUsage.tools.totalCalls += 1; + newUsage.tools.totalTimeMs += executionTime; + + // Ensure cost exists if toolCost is provided + let newCost = cost ? structuredClone(cost) : toolCost ? this.createDefaultCost() : undefined; + + if (toolCost && newCost) { + let toolCostEntry = newCost.tools.byTool.find((entry) => entry.name === toolName); + + if (!toolCostEntry) { + toolCostEntry = { + calls: 0, + currency: 'USD', + name: toolName, + totalCost: 0, + }; + newCost.tools.byTool.push(toolCostEntry); + } + + toolCostEntry.calls += 1; + toolCostEntry.totalCost += toolCost; + newCost.tools.total += toolCost; + newCost.total += toolCost; + newCost.calculatedAt = new Date().toISOString(); + } + + return { cost: newCost, usage: newUsage }; + } +} diff --git a/packages/agent-runtime/src/core/__tests__/InterventionChecker.test.ts b/packages/agent-runtime/src/core/__tests__/InterventionChecker.test.ts new file mode 100644 index 0000000000..771cf4f45b --- /dev/null +++ b/packages/agent-runtime/src/core/__tests__/InterventionChecker.test.ts @@ -0,0 +1,334 @@ +import type { HumanInterventionConfig } from '@lobechat/types'; +import { describe, expect, it } from 'vitest'; + +import { InterventionChecker } from '../InterventionChecker'; + +describe('InterventionChecker', () => { + describe('shouldIntervene', () => { + it('should return never when config is undefined', () => { + const result = InterventionChecker.shouldIntervene({ config: undefined, toolArgs: {} }); + expect(result).toBe('never'); + }); + + it('should return the policy when config is a simple string', () => { + expect(InterventionChecker.shouldIntervene({ config: 'never', toolArgs: {} })).toBe('never'); + expect(InterventionChecker.shouldIntervene({ config: 'always', toolArgs: {} })).toBe( + 'always', + ); + expect(InterventionChecker.shouldIntervene({ config: 'first', toolArgs: {} })).toBe('first'); + }); + + it('should handle "first" policy with confirmed history', () => { + const toolKey = 'web-browsing/crawlSinglePage'; + const confirmedHistory = [toolKey]; + + const result = InterventionChecker.shouldIntervene({ + config: 'first', + toolArgs: {}, + confirmedHistory, + toolKey, + }); + expect(result).toBe('never'); + }); + + it('should require intervention for "first" policy without confirmation', () => { + const toolKey = 'web-browsing/crawlSinglePage'; + const confirmedHistory: string[] = []; + + const result = InterventionChecker.shouldIntervene({ + config: 'first', + toolArgs: {}, + confirmedHistory, + toolKey, + }); + expect(result).toBe('first'); + }); + + it('should match rules in order and return first match', () => { + const config: HumanInterventionConfig = [ + { match: { command: 'ls:*' }, policy: 'never' }, + { match: { command: 'git commit:*' }, policy: 'first' }, + { policy: 'always' }, // Default rule + ]; + + expect(InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'ls:' } })).toBe( + 'never', + ); + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'git commit:' } }), + ).toBe('first'); + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'rm -rf /' } }), + ).toBe('always'); + }); + + it('should return always as default when no rule matches', () => { + const config: HumanInterventionConfig = [{ match: { command: 'ls:*' }, policy: 'never' }]; + + const result = InterventionChecker.shouldIntervene({ + config, + toolArgs: { command: 'rm -rf /' }, + }); + expect(result).toBe('always'); + }); + + it('should handle multiple parameter matching', () => { + const config: HumanInterventionConfig = [ + { + match: { + command: 'git add:*', + path: '/Users/project/*', + }, + policy: 'never', + }, + { policy: 'always' }, + ]; + + // Both match + expect( + InterventionChecker.shouldIntervene({ + config, + toolArgs: { + command: 'git add:.', + path: '/Users/project/file.ts', + }, + }), + ).toBe('never'); + + // Only one matches + expect( + InterventionChecker.shouldIntervene({ + config, + toolArgs: { + command: 'git add:.', + path: '/tmp/file.ts', + }, + }), + ).toBe('always'); + }); + + it('should handle default rule without match', () => { + const config: HumanInterventionConfig = [ + { match: { command: 'ls:*' }, policy: 'never' }, + { policy: 'first' }, // Default rule + ]; + + const result = InterventionChecker.shouldIntervene({ + config, + toolArgs: { command: 'anything' }, + }); + expect(result).toBe('first'); + }); + }); + + describe('matchPattern', () => { + it('should match exact strings', () => { + expect(InterventionChecker['matchPattern']('hello', 'hello')).toBe(true); + expect(InterventionChecker['matchPattern']('hello', 'world')).toBe(false); + }); + + it('should match wildcard patterns', () => { + expect(InterventionChecker['matchPattern']('*.ts', 'file.ts')).toBe(true); + expect(InterventionChecker['matchPattern']('*.ts', 'file.js')).toBe(false); + expect(InterventionChecker['matchPattern']('test*', 'test123')).toBe(true); + expect(InterventionChecker['matchPattern']('test*', 'abc123')).toBe(false); + }); + + it('should match colon-based prefix patterns', () => { + expect(InterventionChecker['matchPattern']('git add:*', 'git add:')).toBe(true); + expect(InterventionChecker['matchPattern']('git add:*', 'git add:.')).toBe(true); + expect(InterventionChecker['matchPattern']('git add:*', 'git add:--all')).toBe(true); + expect(InterventionChecker['matchPattern']('git add:*', 'git commit')).toBe(false); + }); + + it('should match path patterns', () => { + expect( + InterventionChecker['matchPattern']('/Users/project/*', '/Users/project/file.ts'), + ).toBe(true); + expect(InterventionChecker['matchPattern']('/Users/project/*', '/tmp/file.ts')).toBe(false); + }); + }); + + describe('matchesArgument', () => { + it('should match exact type', () => { + const matcher = { pattern: 'git add', type: 'exact' as const }; + expect(InterventionChecker['matchesArgument'](matcher, 'git add')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'git add:.')).toBe(false); + }); + + it('should match prefix type', () => { + const matcher = { pattern: 'git add', type: 'prefix' as const }; + expect(InterventionChecker['matchesArgument'](matcher, 'git add')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'git add:.')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'git commit')).toBe(false); + }); + + it('should match wildcard type', () => { + const matcher = { pattern: 'git *', type: 'wildcard' as const }; + expect(InterventionChecker['matchesArgument'](matcher, 'git add')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'git commit')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'npm install')).toBe(false); + }); + + it('should match regex type', () => { + const matcher = { pattern: '^git (add|commit)', type: 'regex' as const }; + expect(InterventionChecker['matchesArgument'](matcher, 'git add')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'git commit')).toBe(true); + expect(InterventionChecker['matchesArgument'](matcher, 'git push')).toBe(false); + }); + + it('should handle simple string matcher', () => { + expect(InterventionChecker['matchesArgument']('git add:*', 'git add:.')).toBe(true); + expect(InterventionChecker['matchesArgument']('*.ts', 'file.ts')).toBe(true); + expect(InterventionChecker['matchesArgument']('exact', 'exact')).toBe(true); + }); + }); + + describe('generateToolKey', () => { + it('should generate key without args hash', () => { + const key = InterventionChecker.generateToolKey('web-browsing', 'crawlSinglePage'); + expect(key).toBe('web-browsing/crawlSinglePage'); + }); + + it('should generate key with args hash', () => { + const key = InterventionChecker.generateToolKey('bash', 'bash', 'a1b2c3'); + expect(key).toBe('bash/bash#a1b2c3'); + }); + }); + + describe('hashArguments', () => { + it('should generate consistent hash for same arguments', () => { + const args1 = { command: 'ls -la', path: '/tmp' }; + const args2 = { command: 'ls -la', path: '/tmp' }; + + const hash1 = InterventionChecker.hashArguments(args1); + const hash2 = InterventionChecker.hashArguments(args2); + + expect(hash1).toBe(hash2); + }); + + it('should generate different hash for different arguments', () => { + const args1 = { command: 'ls -la' }; + const args2 = { command: 'ls -l' }; + + const hash1 = InterventionChecker.hashArguments(args1); + const hash2 = InterventionChecker.hashArguments(args2); + + expect(hash1).not.toBe(hash2); + }); + + it('should handle key order independence', () => { + const args1 = { a: 1, b: 2 }; + const args2 = { b: 2, a: 1 }; + + const hash1 = InterventionChecker.hashArguments(args1); + const hash2 = InterventionChecker.hashArguments(args2); + + expect(hash1).toBe(hash2); + }); + + it('should handle empty arguments', () => { + const hash = InterventionChecker.hashArguments({}); + expect(hash).toBeDefined(); + expect(typeof hash).toBe('string'); + }); + + it('should handle complex nested objects', () => { + const args = { + config: { nested: { value: 'test' } }, + array: [1, 2, 3], + }; + + const hash = InterventionChecker.hashArguments(args); + expect(hash).toBeDefined(); + expect(typeof hash).toBe('string'); + }); + }); + + describe('Integration scenarios', () => { + it('should handle Bash tool scenario', () => { + const config: HumanInterventionConfig = [ + { match: { command: 'ls:*' }, policy: 'never' }, + { match: { command: 'git add:*' }, policy: 'first' }, + { match: { command: 'git commit:*' }, policy: 'first' }, + { match: { command: 'rm:*' }, policy: 'always' }, + { policy: 'always' }, + ]; + + // Safe commands - never + expect(InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'ls:' } })).toBe( + 'never', + ); + + // Git commands - first + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'git add:.' } }), + ).toBe('first'); + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'git commit:-m' } }), + ).toBe('first'); + + // Dangerous commands - always + expect(InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'rm:-rf' } })).toBe( + 'always', + ); + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { command: 'npm install' } }), + ).toBe('always'); + }); + + it('should handle LocalSystem tool scenario', () => { + const config: HumanInterventionConfig = [ + { match: { path: '/Users/project/*' }, policy: 'never' }, + { policy: 'first' }, + ]; + + // Project directory - never + expect( + InterventionChecker.shouldIntervene({ + config, + toolArgs: { path: '/Users/project/file.ts' }, + }), + ).toBe('never'); + + // Outside project - first + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { path: '/tmp/file.ts' } }), + ).toBe('first'); + }); + + it('should handle Web Browsing tool with simple policy', () => { + const config: HumanInterventionConfig = 'always'; + + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: { url: 'https://example.com' } }), + ).toBe('always'); + }); + + it('should handle first policy with confirmation history', () => { + const config: HumanInterventionConfig = [ + { match: { command: 'git add:*' }, policy: 'first' }, + { policy: 'always' }, + ]; + + const toolKey = 'bash/bash#abc123'; + const args = { command: 'git add:.' }; + + // First time - requires intervention + expect( + InterventionChecker.shouldIntervene({ + config, + toolArgs: args, + confirmedHistory: [], + toolKey, + }), + ).toBe('first'); + + // After confirmation - never + const confirmedHistory = [toolKey]; + expect( + InterventionChecker.shouldIntervene({ config, toolArgs: args, confirmedHistory, toolKey }), + ).toBe('never'); + }); + }); +}); diff --git a/packages/agent-runtime/src/core/__tests__/UsageCounter.test.ts b/packages/agent-runtime/src/core/__tests__/UsageCounter.test.ts new file mode 100644 index 0000000000..70392b9108 --- /dev/null +++ b/packages/agent-runtime/src/core/__tests__/UsageCounter.test.ts @@ -0,0 +1,873 @@ +import { ModelUsage } from '@lobechat/types'; +import { describe, expect, it } from 'vitest'; + +import { UsageCounter } from '../UsageCounter'; +import { AgentRuntime } from '../runtime'; + +describe('UsageCounter', () => { + describe('UsageCounter.accumulateLLM', () => { + it('should accumulate LLM usage tokens', () => { + const state = AgentRuntime.createInitialState(); + + const modelUsage: ModelUsage = { + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const { usage } = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage, + provider: 'openai', + usage: state.usage, + }); + + expect(usage.llm.tokens.input).toBe(100); + expect(usage.llm.tokens.output).toBe(50); + expect(usage.llm.tokens.total).toBe(150); + expect(usage.llm.apiCalls).toBe(1); + }); + + it('should not mutate original usage', () => { + const state = AgentRuntime.createInitialState(); + + const modelUsage: ModelUsage = { + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const { usage } = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: modelUsage, + provider: 'openai', + usage: state.usage, + }); + + expect(state.usage.llm.tokens.input).toBe(0); + expect(usage).not.toBe(state.usage); + }); + + it('should create new byModel entry when not exists', () => { + const state = AgentRuntime.createInitialState(); + + const modelUsage: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const { cost } = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: modelUsage, + provider: 'openai', + usage: state.usage, + }); + + expect(cost?.llm.byModel).toHaveLength(1); + expect(cost?.llm.byModel[0]).toEqual({ + id: 'openai/gpt-4', + model: 'gpt-4', + provider: 'openai', + totalCost: 0.05, + usage: { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }, + }); + }); + + it('should accumulate to existing byModel entry', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const usage2: ModelUsage = { + cost: 0.03, + totalInputTokens: 50, + totalOutputTokens: 25, + totalTokens: 75, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'gpt-4', + modelUsage: usage2, + provider: 'openai', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel).toHaveLength(1); + expect(result2.cost?.llm.byModel[0]).toEqual({ + id: 'openai/gpt-4', + model: 'gpt-4', + provider: 'openai', + totalCost: 0.08, + usage: { + cost: 0.08, + totalInputTokens: 150, + totalOutputTokens: 75, + totalTokens: 225, + }, + }); + }); + + it('should accumulate multiple models separately', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const usage2: ModelUsage = { + cost: 0.02, + totalInputTokens: 50, + totalOutputTokens: 25, + totalTokens: 75, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: usage2, + provider: 'anthropic', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel).toHaveLength(2); + expect(result2.cost?.llm.byModel[0].id).toBe('openai/gpt-4'); + expect(result2.cost?.llm.byModel[1].id).toBe('anthropic/claude-3-5-sonnet-20241022'); + }); + + it('should accumulate cache-related tokens', () => { + const state = AgentRuntime.createInitialState(); + + const modelUsage: ModelUsage = { + cost: 0.05, + inputCacheMissTokens: 60, + inputCachedTokens: 40, + inputWriteCacheTokens: 20, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const { cost } = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: modelUsage, + provider: 'anthropic', + usage: state.usage, + }); + + expect(cost?.llm.byModel[0].usage).toEqual({ + cost: 0.05, + inputCacheMissTokens: 60, + inputCachedTokens: 40, + inputWriteCacheTokens: 20, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }); + }); + + it('should accumulate total costs correctly', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const usage2: ModelUsage = { + cost: 0.03, + totalInputTokens: 50, + totalOutputTokens: 25, + totalTokens: 75, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: usage2, + provider: 'anthropic', + usage: result1.usage, + }); + + expect(result2.cost?.llm.total).toBe(0.08); + expect(result2.cost?.total).toBe(0.08); + expect(result2.cost?.calculatedAt).toBeDefined(); + }); + + it('should not accumulate cost when usage.cost is undefined', () => { + const state = AgentRuntime.createInitialState(); + + const modelUsage: ModelUsage = { + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const { cost } = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: modelUsage, + provider: 'openai', + usage: state.usage, + }); + + expect(cost?.llm.byModel).toHaveLength(0); + expect(cost?.llm.total).toBe(0); + expect(cost?.total).toBe(0); + }); + + it('should increment apiCalls for each accumulation', () => { + const state = AgentRuntime.createInitialState(); + + const modelUsage: ModelUsage = { + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: modelUsage, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'gpt-4', + modelUsage: modelUsage, + provider: 'openai', + usage: result1.usage, + }); + const result3 = UsageCounter.accumulateLLM({ + cost: result2.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: modelUsage, + provider: 'anthropic', + usage: result2.usage, + }); + + expect(result3.usage.llm.apiCalls).toBe(3); + }); + + it('should auto-create usage and cost when not provided', () => { + const modelUsage: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const { usage, cost } = UsageCounter.accumulateLLM({ + model: 'gpt-4', + modelUsage, + provider: 'openai', + }); + + expect(usage).toBeDefined(); + expect(usage.llm.tokens.input).toBe(100); + expect(usage.llm.tokens.output).toBe(50); + expect(usage.llm.tokens.total).toBe(150); + expect(usage.llm.apiCalls).toBe(1); + + expect(cost).toBeDefined(); + expect(cost?.total).toBe(0.05); + expect(cost?.llm.total).toBe(0.05); + }); + }); + + describe('UsageCounter.accumulateTool', () => { + it('should accumulate tool usage', () => { + const state = AgentRuntime.createInitialState(); + + const { usage } = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: true, + toolName: 'search', + usage: state.usage, + }); + + expect(usage.tools.byTool).toHaveLength(1); + expect(usage.tools.byTool[0]).toEqual({ + calls: 1, + errors: 0, + name: 'search', + totalTimeMs: 1000, + }); + expect(usage.tools.totalCalls).toBe(1); + expect(usage.tools.totalTimeMs).toBe(1000); + }); + + it('should not mutate original usage', () => { + const state = AgentRuntime.createInitialState(); + + const { usage } = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: true, + toolName: 'search', + usage: state.usage, + }); + + expect(state.usage.tools.totalCalls).toBe(0); + expect(usage).not.toBe(state.usage); + }); + + it('should accumulate errors when success is false', () => { + const state = AgentRuntime.createInitialState(); + + const { usage } = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: false, + toolName: 'search', + usage: state.usage, + }); + + expect(usage.tools.byTool[0]).toEqual({ + calls: 1, + errors: 1, + name: 'search', + totalTimeMs: 1000, + }); + }); + + it('should accumulate multiple tool calls', () => { + const state = AgentRuntime.createInitialState(); + + const result1 = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: true, + toolName: 'search', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateTool({ + cost: result1.cost, + executionTime: 500, + success: true, + toolName: 'search', + usage: result1.usage, + }); + const result3 = UsageCounter.accumulateTool({ + cost: result2.cost, + executionTime: 200, + success: false, + toolName: 'calculator', + usage: result2.usage, + }); + + expect(result3.usage.tools.byTool).toHaveLength(2); + expect(result3.usage.tools.byTool.find((t) => t.name === 'search')).toEqual({ + calls: 2, + errors: 0, + name: 'search', + totalTimeMs: 1500, + }); + expect(result3.usage.tools.byTool.find((t) => t.name === 'calculator')).toEqual({ + calls: 1, + errors: 1, + name: 'calculator', + totalTimeMs: 200, + }); + expect(result3.usage.tools.totalCalls).toBe(3); + expect(result3.usage.tools.totalTimeMs).toBe(1700); + }); + + it('should accumulate tool cost when provided', () => { + const state = AgentRuntime.createInitialState(); + + const { cost } = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: true, + toolCost: 0.01, + toolName: 'premium-search', + usage: state.usage, + }); + + expect(cost?.tools.byTool).toHaveLength(1); + expect(cost?.tools.byTool[0]).toEqual({ + calls: 1, + currency: 'USD', + name: 'premium-search', + totalCost: 0.01, + }); + expect(cost?.tools.total).toBe(0.01); + expect(cost?.total).toBe(0.01); + }); + + it('should accumulate tool cost across multiple calls', () => { + const state = AgentRuntime.createInitialState(); + + const result1 = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: true, + toolCost: 0.01, + toolName: 'premium-search', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateTool({ + cost: result1.cost, + executionTime: 500, + success: true, + toolCost: 0.005, + toolName: 'premium-search', + usage: result1.usage, + }); + + expect(result2.cost?.tools.byTool).toHaveLength(1); + expect(result2.cost?.tools.byTool[0]).toEqual({ + calls: 2, + currency: 'USD', + name: 'premium-search', + totalCost: 0.015, + }); + expect(result2.cost?.tools.total).toBe(0.015); + expect(result2.cost?.total).toBe(0.015); + }); + + it('should not accumulate cost when cost is undefined', () => { + const state = AgentRuntime.createInitialState(); + + const { cost } = UsageCounter.accumulateTool({ + cost: state.cost, + executionTime: 1000, + success: true, + toolName: 'free-tool', + usage: state.usage, + }); + + expect(cost?.tools.byTool).toHaveLength(0); + expect(cost?.tools.total).toBe(0); + }); + }); + + describe('mixed accumulation', () => { + it('should accumulate both LLM and tool costs correctly', () => { + const state = AgentRuntime.createInitialState(); + + const llmUsage: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: llmUsage, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateTool({ + cost: result1.cost, + executionTime: 1000, + success: true, + toolCost: 0.01, + toolName: 'premium-search', + usage: result1.usage, + }); + + expect(result2.cost?.llm.total).toBe(0.05); + expect(result2.cost?.tools.total).toBe(0.01); + expect(result2.cost?.total).toBeCloseTo(0.06); + }); + }); + + describe('mergeModelUsage (private method tests via accumulateLLM)', () => { + it('should merge basic token counts', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const usage2: ModelUsage = { + cost: 0.03, + totalInputTokens: 200, + totalOutputTokens: 100, + totalTokens: 300, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'gpt-4', + modelUsage: usage2, + provider: 'openai', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + cost: 0.08, + totalInputTokens: 300, + totalOutputTokens: 150, + totalTokens: 450, + }); + }); + + it('should merge cache-related tokens', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + inputCacheMissTokens: 30, + inputCachedTokens: 50, + inputWriteCacheTokens: 20, + totalInputTokens: 100, + totalOutputTokens: 50, + totalTokens: 150, + }; + + const usage2: ModelUsage = { + cost: 0.03, + inputCacheMissTokens: 40, + inputCachedTokens: 80, + inputWriteCacheTokens: 30, + totalInputTokens: 150, + totalOutputTokens: 75, + totalTokens: 225, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: usage1, + provider: 'anthropic', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: usage2, + provider: 'anthropic', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + cost: 0.08, + inputCacheMissTokens: 70, + inputCachedTokens: 130, + inputWriteCacheTokens: 50, + totalInputTokens: 250, + totalOutputTokens: 125, + totalTokens: 375, + }); + }); + + it('should merge reasoning tokens', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + outputReasoningTokens: 100, + outputTextTokens: 200, + totalInputTokens: 100, + totalOutputTokens: 300, + totalTokens: 400, + }; + + const usage2: ModelUsage = { + cost: 0.03, + outputReasoningTokens: 50, + outputTextTokens: 100, + totalInputTokens: 50, + totalOutputTokens: 150, + totalTokens: 200, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'o1', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'o1', + modelUsage: usage2, + provider: 'openai', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + cost: 0.08, + outputReasoningTokens: 150, + outputTextTokens: 300, + totalInputTokens: 150, + totalOutputTokens: 450, + totalTokens: 600, + }); + }); + + it('should merge audio and image tokens', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + inputAudioTokens: 10, + inputImageTokens: 20, + outputAudioTokens: 5, + outputImageTokens: 15, + totalInputTokens: 30, + totalOutputTokens: 20, + totalTokens: 50, + }; + + const usage2: ModelUsage = { + cost: 0.03, + inputAudioTokens: 15, + inputImageTokens: 25, + outputAudioTokens: 8, + outputImageTokens: 12, + totalInputTokens: 40, + totalOutputTokens: 20, + totalTokens: 60, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4o-audio-preview', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'gpt-4o-audio-preview', + modelUsage: usage2, + provider: 'openai', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + cost: 0.08, + inputAudioTokens: 25, + inputImageTokens: 45, + outputAudioTokens: 13, + outputImageTokens: 27, + totalInputTokens: 70, + totalOutputTokens: 40, + totalTokens: 110, + }); + }); + + it('should merge prediction tokens', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + acceptedPredictionTokens: 50, + cost: 0.05, + rejectedPredictionTokens: 10, + totalInputTokens: 100, + totalOutputTokens: 60, + totalTokens: 160, + }; + + const usage2: ModelUsage = { + acceptedPredictionTokens: 30, + cost: 0.03, + rejectedPredictionTokens: 5, + totalInputTokens: 50, + totalOutputTokens: 35, + totalTokens: 85, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4o', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'gpt-4o', + modelUsage: usage2, + provider: 'openai', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + acceptedPredictionTokens: 80, + cost: 0.08, + rejectedPredictionTokens: 15, + totalInputTokens: 150, + totalOutputTokens: 95, + totalTokens: 245, + }); + }); + + it('should handle missing fields gracefully', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + cost: 0.05, + totalInputTokens: 100, + // totalOutputTokens is missing + }; + + const usage2: ModelUsage = { + cost: 0.03, + totalOutputTokens: 50, + // totalInputTokens is missing + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'gpt-4', + modelUsage: usage1, + provider: 'openai', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'gpt-4', + modelUsage: usage2, + provider: 'openai', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + cost: 0.08, + totalInputTokens: 100, + totalOutputTokens: 50, + }); + }); + + it('should merge all fields in a comprehensive scenario', () => { + const state = AgentRuntime.createInitialState(); + + const usage1: ModelUsage = { + acceptedPredictionTokens: 10, + cost: 0.05, + inputAudioTokens: 5, + inputCacheMissTokens: 40, + inputCachedTokens: 60, + inputCitationTokens: 10, + inputImageTokens: 20, + inputTextTokens: 100, + inputWriteCacheTokens: 30, + outputAudioTokens: 3, + outputImageTokens: 8, + outputReasoningTokens: 20, + outputTextTokens: 50, + rejectedPredictionTokens: 5, + totalInputTokens: 200, + totalOutputTokens: 80, + totalTokens: 280, + }; + + const usage2: ModelUsage = { + acceptedPredictionTokens: 5, + cost: 0.03, + inputAudioTokens: 3, + inputCacheMissTokens: 20, + inputCachedTokens: 30, + inputCitationTokens: 5, + inputImageTokens: 10, + inputTextTokens: 50, + inputWriteCacheTokens: 15, + outputAudioTokens: 2, + outputImageTokens: 4, + outputReasoningTokens: 10, + outputTextTokens: 25, + rejectedPredictionTokens: 2, + totalInputTokens: 100, + totalOutputTokens: 40, + totalTokens: 140, + }; + + const result1 = UsageCounter.accumulateLLM({ + cost: state.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: usage1, + provider: 'anthropic', + usage: state.usage, + }); + const result2 = UsageCounter.accumulateLLM({ + cost: result1.cost, + model: 'claude-3-5-sonnet-20241022', + modelUsage: usage2, + provider: 'anthropic', + usage: result1.usage, + }); + + expect(result2.cost?.llm.byModel[0].usage).toEqual({ + acceptedPredictionTokens: 15, + cost: 0.08, + inputAudioTokens: 8, + inputCacheMissTokens: 60, + inputCachedTokens: 90, + inputCitationTokens: 15, + inputImageTokens: 30, + inputTextTokens: 150, + inputWriteCacheTokens: 45, + outputAudioTokens: 5, + outputImageTokens: 12, + outputReasoningTokens: 30, + outputTextTokens: 75, + rejectedPredictionTokens: 7, + totalInputTokens: 300, + totalOutputTokens: 120, + totalTokens: 420, + }); + }); + }); +}); diff --git a/packages/agent-runtime/src/core/__tests__/runtime.test.ts b/packages/agent-runtime/src/core/__tests__/runtime.test.ts index 76e1ff29bf..09c83d082b 100644 --- a/packages/agent-runtime/src/core/__tests__/runtime.test.ts +++ b/packages/agent-runtime/src/core/__tests__/runtime.test.ts @@ -369,7 +369,7 @@ describe('AgentRuntime', () => { type: 'tool_pending', }); - expect(result.newState.status).toBe('waiting_for_human_input'); + expect(result.newState.status).toBe('waiting_for_human'); expect(result.newState.pendingToolsCalling).toBeDefined(); }); @@ -396,7 +396,7 @@ describe('AgentRuntime', () => { sessionId: 'test-session', }); - expect(result.newState.status).toBe('waiting_for_human_input'); + expect(result.newState.status).toBe('waiting_for_human'); expect(result.newState.pendingHumanPrompt).toEqual({ prompt: 'Please provide input', metadata: { key: 'value' }, @@ -434,7 +434,7 @@ describe('AgentRuntime', () => { sessionId: 'test-session', }); - expect(result.newState.status).toBe('waiting_for_human_input'); + expect(result.newState.status).toBe('waiting_for_human'); }); }); @@ -733,7 +733,7 @@ describe('AgentRuntime', () => { }, tools: { totalCalls: 0, - byTool: {}, + byTool: [], totalTimeMs: 0, }, humanInteraction: { @@ -746,12 +746,12 @@ describe('AgentRuntime', () => { expect(state.cost).toMatchObject({ llm: { - byModel: {}, + byModel: [], total: 0, currency: 'USD', }, tools: { - byTool: {}, + byTool: [], total: 0, currency: 'USD', }, @@ -890,8 +890,8 @@ describe('AgentRuntime', () => { calculateCost(context: CostCalculationContext): Cost { return { - llm: { byModel: {}, total: 15.0, currency: 'USD' }, - tools: { byTool: {}, total: 0, currency: 'USD' }, + llm: { byModel: [], total: 15.0, currency: 'USD' }, + tools: { byTool: [], total: 0, currency: 'USD' }, total: 15.0, currency: 'USD', calculatedAt: new Date().toISOString(), @@ -1018,7 +1018,7 @@ describe('AgentRuntime', () => { result = await runtime.step(result.newState, result.nextContext); // Now should request human approval - expect(result.newState.status).toBe('waiting_for_human_input'); + expect(result.newState.status).toBe('waiting_for_human'); expect(result.newState.pendingToolsCalling).toHaveLength(1); // Step 2: Approve and execute tool call @@ -1210,7 +1210,7 @@ describe('AgentRuntime', () => { expect(agent.tools.safe_tool).toHaveBeenCalled(); // Should be in waiting state (blocked by approval request) - expect(result.newState.status).toBe('waiting_for_human_input'); + expect(result.newState.status).toBe('waiting_for_human'); // Should have pending tool calls expect(result.newState.pendingToolsCalling).toHaveLength(1); @@ -1333,8 +1333,8 @@ describe('AgentRuntime', () => { return { calculatedAt: new Date().toISOString(), currency: 'USD', - llm: { byModel: {}, currency: 'USD', total: 15.0 }, - tools: { byTool: {}, currency: 'USD', total: 0 }, + llm: { byModel: [], currency: 'USD', total: 15.0 }, + tools: { byTool: [], currency: 'USD', total: 0 }, total: 15.0, }; } @@ -1396,8 +1396,8 @@ describe('AgentRuntime', () => { return { calculatedAt: new Date().toISOString(), currency: 'USD', - llm: { byModel: {}, currency: 'USD', total: 0 }, - tools: { byTool: {}, currency: 'USD', total: 20.0 }, + llm: { byModel: [], currency: 'USD', total: 0 }, + tools: { byTool: [], currency: 'USD', total: 20.0 }, total: 20.0, }; } @@ -1438,8 +1438,8 @@ describe('AgentRuntime', () => { const baseCost = context.previousCost || { calculatedAt: new Date().toISOString(), currency: 'USD', - llm: { byModel: {}, currency: 'USD', total: 0 }, - tools: { byTool: {}, currency: 'USD', total: 0 }, + llm: { byModel: [], currency: 'USD', total: 0 }, + tools: { byTool: [], currency: 'USD', total: 0 }, total: 0, }; @@ -1447,7 +1447,7 @@ describe('AgentRuntime', () => { ...baseCost, calculatedAt: new Date().toISOString(), tools: { - byTool: {}, + byTool: [], currency: 'USD', total: baseCost.tools.total + 5.0, }, @@ -1514,15 +1514,17 @@ describe('AgentRuntime', () => { newUsage.tools.totalCalls += 1; newUsage.tools.totalTimeMs += 100; - if (newUsage.tools.byTool[toolName]) { - newUsage.tools.byTool[toolName].calls += 1; - newUsage.tools.byTool[toolName].totalTimeMs += 100; + const existingTool = newUsage.tools.byTool.find((t) => t.name === toolName); + if (existingTool) { + existingTool.calls += 1; + existingTool.totalTimeMs += 100; } else { - newUsage.tools.byTool[toolName] = { + newUsage.tools.byTool.push({ calls: 1, errors: 0, + name: toolName, totalTimeMs: 100, - }; + }); } return newUsage; @@ -1567,10 +1569,14 @@ describe('AgentRuntime', () => { // Should have per-tool statistics expect(result.newState.usage.tools.totalCalls).toBe(2); - expect(result.newState.usage.tools.byTool.analytics_tool).toBeDefined(); - expect(result.newState.usage.tools.byTool.analytics_tool.calls).toBe(1); - expect(result.newState.usage.tools.byTool.logging_tool).toBeDefined(); - expect(result.newState.usage.tools.byTool.logging_tool.calls).toBe(1); + const analyticsTool = result.newState.usage.tools.byTool.find( + (t) => t.name === 'analytics_tool', + ); + const loggingTool = result.newState.usage.tools.byTool.find((t) => t.name === 'logging_tool'); + expect(analyticsTool).toBeDefined(); + expect(analyticsTool!.calls).toBe(1); + expect(loggingTool).toBeDefined(); + expect(loggingTool!.calls).toBe(1); }); }); }); diff --git a/packages/agent-runtime/src/core/index.ts b/packages/agent-runtime/src/core/index.ts index 2bc3e7914d..b5aac38a1e 100644 --- a/packages/agent-runtime/src/core/index.ts +++ b/packages/agent-runtime/src/core/index.ts @@ -1 +1,3 @@ +export * from './InterventionChecker'; export * from './runtime'; +export * from './UsageCounter'; diff --git a/packages/agent-runtime/src/core/runtime.ts b/packages/agent-runtime/src/core/runtime.ts index f0e85ee158..3eafda6f8e 100644 --- a/packages/agent-runtime/src/core/runtime.ts +++ b/packages/agent-runtime/src/core/runtime.ts @@ -120,10 +120,7 @@ export class AgentRuntime { } // Stop execution if blocked - if ( - currentState.status === 'waiting_for_human_input' || - currentState.status === 'interrupted' - ) { + if (currentState.status === 'waiting_for_human' || currentState.status === 'interrupted') { break; } } @@ -273,7 +270,7 @@ export class AgentRuntime { tokens: { input: 0, output: 0, total: 0 }, }, tools: { - byTool: {}, + byTool: [], totalCalls: 0, totalTimeMs: 0, }, @@ -290,12 +287,12 @@ export class AgentRuntime { calculatedAt: now, currency: 'USD', llm: { - byModel: {}, + byModel: [], currency: 'USD', total: 0, }, tools: { - byTool: {}, + byTool: [], currency: 'USD', total: 0, }, @@ -308,7 +305,9 @@ export class AgentRuntime { * @param partialState - Partial state to override defaults * @returns Complete AgentState with defaults filled in */ - static createInitialState(partialState: Partial & { sessionId: string }): AgentState { + static createInitialState( + partialState?: Partial & { sessionId: string }, + ): AgentState { const now = new Date().toISOString(); return { @@ -319,9 +318,10 @@ export class AgentRuntime { messages: [], status: 'idle', stepCount: 0, + toolManifestMap: {}, usage: AgentRuntime.createDefaultUsage(), // User provided values override defaults - ...partialState, + ...(partialState || { sessionId: '' }), }; } @@ -489,7 +489,7 @@ export class AgentRuntime { const newState = structuredClone(state); newState.lastModified = new Date().toISOString(); - newState.status = 'waiting_for_human_input'; + newState.status = 'waiting_for_human'; newState.pendingToolsCalling = pendingToolsCalling; const events: AgentEvent[] = [ @@ -515,7 +515,7 @@ export class AgentRuntime { const newState = structuredClone(state); newState.lastModified = new Date().toISOString(); - newState.status = 'waiting_for_human_input'; + newState.status = 'waiting_for_human'; newState.pendingHumanPrompt = { metadata, prompt }; const events: AgentEvent[] = [ @@ -541,7 +541,7 @@ export class AgentRuntime { const newState = structuredClone(state); newState.lastModified = new Date().toISOString(); - newState.status = 'waiting_for_human_input'; + newState.status = 'waiting_for_human'; newState.pendingHumanSelect = { metadata, multi, options, prompt }; const events: AgentEvent[] = [ @@ -641,13 +641,15 @@ export class AgentRuntime { newState.usage.tools.totalCalls += result.newState.usage.tools.totalCalls; newState.usage.tools.totalTimeMs += result.newState.usage.tools.totalTimeMs; - // Merge per-tool statistics - Object.entries(result.newState.usage.tools.byTool).forEach(([tool, stats]) => { - if (newState.usage.tools.byTool[tool]) { - newState.usage.tools.byTool[tool].calls += stats.calls; - newState.usage.tools.byTool[tool].totalTimeMs += stats.totalTimeMs; + // Merge per-tool statistics (now using array) + result.newState.usage.tools.byTool.forEach((toolStats) => { + const existingTool = newState.usage.tools.byTool.find((t) => t.name === toolStats.name); + if (existingTool) { + existingTool.calls += toolStats.calls; + existingTool.totalTimeMs += toolStats.totalTimeMs; + existingTool.errors += toolStats.errors || 0; } else { - newState.usage.tools.byTool[tool] = { ...stats }; + newState.usage.tools.byTool.push({ ...toolStats }); } }); } @@ -656,6 +658,17 @@ export class AgentRuntime { if (result.newState.cost && newState.cost) { newState.cost.tools.total += result.newState.cost.tools.total; newState.cost.total += result.newState.cost.tools.total; + + // Merge per-tool cost statistics (now using array) + result.newState.cost.tools.byTool.forEach((toolCost) => { + const existingToolCost = newState.cost.tools.byTool.find((t) => t.name === toolCost.name); + if (existingToolCost) { + existingToolCost.calls += toolCost.calls; + existingToolCost.totalCost += toolCost.totalCost; + } else { + newState.cost.tools.byTool.push({ ...toolCost }); + } + }); } } diff --git a/packages/agent-runtime/src/types/instruction.ts b/packages/agent-runtime/src/types/instruction.ts index 517361c05f..f5f05dfc8d 100644 --- a/packages/agent-runtime/src/types/instruction.ts +++ b/packages/agent-runtime/src/types/instruction.ts @@ -29,7 +29,7 @@ export interface AgentRuntimeContext { stepCount: number; }; /** Usage statistics from the current step (if applicable) */ - stepUsage?: ModelUsage; + stepUsage?: ModelUsage | unknown; } /** diff --git a/packages/agent-runtime/src/types/state.ts b/packages/agent-runtime/src/types/state.ts index e1c23f6f05..dd4131e2c1 100644 --- a/packages/agent-runtime/src/types/state.ts +++ b/packages/agent-runtime/src/types/state.ts @@ -8,13 +8,13 @@ import type { Cost, CostLimit, Usage } from './usage'; export interface AgentState { sessionId: string; // --- State Machine --- - status: 'idle' | 'running' | 'waiting_for_human_input' | 'done' | 'error' | 'interrupted'; + status: 'idle' | 'running' | 'waiting_for_human' | 'done' | 'error' | 'interrupted'; // --- Core Context --- messages: any[]; tools?: any[]; systemRole?: string; - + toolManifestMap: Record; // --- Execution Tracking --- /** * Number of execution steps in this session. @@ -46,7 +46,7 @@ export interface AgentState { // --- HIL --- /** - * When status is 'waiting_for_human_input', this stores pending requests + * When status is 'waiting_for_human', this stores pending requests * for human-in-the-loop operations. */ pendingToolsCalling?: ToolsCalling[]; diff --git a/packages/agent-runtime/src/types/usage.ts b/packages/agent-runtime/src/types/usage.ts index 41a22bfbd5..44fd4585fb 100644 --- a/packages/agent-runtime/src/types/usage.ts +++ b/packages/agent-runtime/src/types/usage.ts @@ -1,3 +1,5 @@ +import { ModelUsage } from '@lobechat/types'; + /** * Token usage tracking for different types of operations */ @@ -39,14 +41,16 @@ export interface Usage { /** Tool usage statistics */ tools: { /** Usage breakdown by tool name */ - byTool: Record< - string, - { - calls: number; - errors: number; - totalTimeMs: number; - } - >; + byTool: Array<{ + /** Number of calls */ + calls: number; + /** Number of errors */ + errors: number; + /** Tool name/identifier */ + name: string; + /** Total execution time in milliseconds */ + totalTimeMs: number; + }>; /** Number of tool calls executed */ totalCalls: number; /** Total tool execution time */ @@ -66,15 +70,18 @@ export interface Cost { /** LLM API costs */ llm: { /** Cost per model used */ - byModel: Record< - string, - { - currency: string; - inputTokens: number; - outputTokens: number; - totalCost: number; - } - >; + byModel: Array<{ + /** Model identifier in format "provider/model" */ + id: string; + /** Model name */ + model: string; + /** Provider name */ + provider: string; + /** Total cost for this model */ + totalCost: number; + /** Detailed usage breakdown */ + usage: ModelUsage; + }>; currency: string; /** Total LLM cost */ total: number; @@ -82,14 +89,16 @@ export interface Cost { /** Tool execution costs */ tools: { /** Cost per tool (if tool has associated costs) */ - byTool: Record< - string, - { - calls: number; - currency: string; - totalCost: number; - } - >; + byTool: Array<{ + /** Number of calls */ + calls: number; + /** Currency */ + currency: string; + /** Tool name/identifier */ + name: string; + /** Total cost for this tool */ + totalCost: number; + }>; currency: string; /** Total tool cost */ total: number; diff --git a/packages/context-engine/src/index.ts b/packages/context-engine/src/index.ts index bd6e90d927..a7414e1161 100644 --- a/packages/context-engine/src/index.ts +++ b/packages/context-engine/src/index.ts @@ -32,6 +32,7 @@ export { export type { FunctionCallChecker, GenerateToolsParams, + LobeToolManifest, PluginEnableChecker, ToolNameGenerator, ToolsEngineOptions, diff --git a/packages/context-engine/src/tools/ToolNameResolver.ts b/packages/context-engine/src/tools/ToolNameResolver.ts index dcd31e7ddd..5bd3fb2b1d 100644 --- a/packages/context-engine/src/tools/ToolNameResolver.ts +++ b/packages/context-engine/src/tools/ToolNameResolver.ts @@ -1,7 +1,7 @@ import { ChatToolPayload, MessageToolCall } from '@lobechat/types'; import { Md5 } from 'ts-md5'; -import { LobeChatPluginApi, LobeChatPluginManifest } from './types'; +import { LobeChatPluginApi, LobeToolManifest } from './types'; // Tool naming constants const PLUGIN_SCHEMA_SEPARATOR = '____'; @@ -57,7 +57,7 @@ export class ToolNameResolver { */ resolve( toolCalls: MessageToolCall[], - manifests: Record, + manifests: Record, ): ChatToolPayload[] { return toolCalls .map((toolCall): ChatToolPayload | null => { diff --git a/packages/context-engine/src/tools/ToolsEngine.ts b/packages/context-engine/src/tools/ToolsEngine.ts index 402ffcd9ca..0b9caf7391 100644 --- a/packages/context-engine/src/tools/ToolsEngine.ts +++ b/packages/context-engine/src/tools/ToolsEngine.ts @@ -3,7 +3,7 @@ import debug from 'debug'; import { FunctionCallChecker, GenerateToolsParams, - LobeChatPluginManifest, + LobeToolManifest, PluginEnableChecker, ToolsEngineOptions, ToolsGenerationContext, @@ -18,7 +18,7 @@ const log = debug('context-engine:tools-engine'); * Tools Engine - Unified processing of tools array construction and transformation */ export class ToolsEngine { - private manifestSchemas: Map; + private manifestSchemas: Map; private enableChecker?: PluginEnableChecker; private functionCallChecker?: FunctionCallChecker; private defaultToolIds: string[]; @@ -162,13 +162,13 @@ export class ToolsEngine { context?: ToolsGenerationContext, supportsFunctionCall?: boolean, ): { - enabledManifests: LobeChatPluginManifest[]; + enabledManifests: LobeToolManifest[]; filteredPlugins: Array<{ id: string; reason: 'not_found' | 'disabled' | 'incompatible'; }>; } { - const enabledManifests: LobeChatPluginManifest[] = []; + const enabledManifests: LobeToolManifest[] = []; const filteredPlugins: Array<{ id: string; reason: 'not_found' | 'disabled' | 'incompatible'; @@ -240,7 +240,7 @@ export class ToolsEngine { /** * Convert manifests to UniformTool array */ - private convertManifestsToTools(manifests: LobeChatPluginManifest[]): UniformTool[] { + private convertManifestsToTools(manifests: LobeToolManifest[]): UniformTool[] { log('Converting %d manifests to tools', manifests.length); // Use simplified conversion logic to avoid external package dependencies @@ -290,14 +290,14 @@ export class ToolsEngine { /** * 获取插件的 manifest */ - getPluginManifest(pluginId: string): LobeChatPluginManifest | undefined { + getPluginManifest(pluginId: string): LobeToolManifest | undefined { return this.manifestSchemas.get(pluginId); } /** * 更新插件 manifest schemas(用于动态添加插件) */ - updateManifestSchemas(manifestSchemas: LobeChatPluginManifest[]): void { + updateManifestSchemas(manifestSchemas: LobeToolManifest[]): void { this.manifestSchemas.clear(); for (const schema of manifestSchemas) { this.manifestSchemas.set(schema.identifier, schema); @@ -307,7 +307,7 @@ export class ToolsEngine { /** * 添加单个插件 manifest */ - addPluginManifest(manifest: LobeChatPluginManifest): void { + addPluginManifest(manifest: LobeToolManifest): void { this.manifestSchemas.set(manifest.identifier, manifest); } @@ -317,4 +317,33 @@ export class ToolsEngine { removePluginManifest(pluginId: string): boolean { return this.manifestSchemas.delete(pluginId); } + + /** + * 获取所有 enabled plugin 的 Manifest Map + */ + getEnabledPluginManifests(toolIds: string[] = []): Map { + // Merge user-provided tool IDs with default tool IDs + const allToolIds = [...toolIds, ...this.defaultToolIds]; + + log('Getting enabled plugin manifests for pluginIds=%o', allToolIds); + + const manifestMap = new Map(); + + for (const pluginId of allToolIds) { + const manifest = this.manifestSchemas.get(pluginId); + if (manifest) { + manifestMap.set(pluginId, manifest); + } + } + + log('Returning %d enabled plugin manifests', manifestMap.size); + return manifestMap; + } + + /** + * 获取所有插件的 Manifest Map + */ + getAllPluginManifests(): Map { + return new Map(this.manifestSchemas); + } } diff --git a/packages/context-engine/src/tools/__tests__/ToolsEngine.test.ts b/packages/context-engine/src/tools/__tests__/ToolsEngine.test.ts index a9b07407c9..95406912dc 100644 --- a/packages/context-engine/src/tools/__tests__/ToolsEngine.test.ts +++ b/packages/context-engine/src/tools/__tests__/ToolsEngine.test.ts @@ -1,10 +1,10 @@ import { describe, expect, it, vi } from 'vitest'; import { ToolsEngine } from '../ToolsEngine'; -import type { LobeChatPluginManifest } from '../types'; +import type { LobeToolManifest } from '../types'; // Mock manifest schemas for testing -const mockWebBrowsingManifest: LobeChatPluginManifest = { +const mockWebBrowsingManifest: LobeToolManifest = { api: [ { description: 'Search the web', @@ -26,7 +26,7 @@ const mockWebBrowsingManifest: LobeChatPluginManifest = { type: 'builtin', }; -const mockDalleManifest: LobeChatPluginManifest = { +const mockDalleManifest: LobeToolManifest = { api: [ { description: 'Generate images', @@ -337,6 +337,150 @@ describe('ToolsEngine', () => { }); }); + describe('getEnabledPluginManifests', () => { + it('should return empty Map when no tool IDs provided', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + }); + + const result = engine.getEnabledPluginManifests([]); + + expect(result.size).toBe(0); + }); + + it('should return Map with plugin manifests for given tool IDs', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + }); + + const result = engine.getEnabledPluginManifests(['lobe-web-browsing', 'dalle']); + + expect(result.size).toBe(2); + expect(result.get('lobe-web-browsing')).toBe(mockWebBrowsingManifest); + expect(result.get('dalle')).toBe(mockDalleManifest); + }); + + it('should include default tool IDs in the result', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + defaultToolIds: ['dalle'], + }); + + const result = engine.getEnabledPluginManifests(['lobe-web-browsing']); + + expect(result.size).toBe(2); + expect(result.has('lobe-web-browsing')).toBe(true); + expect(result.has('dalle')).toBe(true); + }); + + it('should handle non-existent plugins gracefully', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest], + }); + + const result = engine.getEnabledPluginManifests(['lobe-web-browsing', 'non-existent']); + + expect(result.size).toBe(1); + expect(result.has('lobe-web-browsing')).toBe(true); + expect(result.has('non-existent')).toBe(false); + }); + + it('should return all default tools when called without arguments', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + defaultToolIds: ['dalle', 'lobe-web-browsing'], + }); + + const result = engine.getEnabledPluginManifests(); + + expect(result.size).toBe(2); + expect(result.has('lobe-web-browsing')).toBe(true); + expect(result.has('dalle')).toBe(true); + }); + + it('should not duplicate tools when same ID appears in both toolIds and defaultToolIds', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + defaultToolIds: ['dalle'], + }); + + const result = engine.getEnabledPluginManifests(['dalle']); + + expect(result.size).toBe(1); + expect(result.get('dalle')).toBe(mockDalleManifest); + }); + }); + + describe('getAllPluginManifests', () => { + it('should return all plugin manifests', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + }); + + const result = engine.getAllPluginManifests(); + + expect(result.size).toBe(2); + expect(result.get('lobe-web-browsing')).toBe(mockWebBrowsingManifest); + expect(result.get('dalle')).toBe(mockDalleManifest); + }); + + it('should return empty Map when no manifests are loaded', () => { + const engine = new ToolsEngine({ + manifestSchemas: [], + }); + + const result = engine.getAllPluginManifests(); + + expect(result.size).toBe(0); + }); + + it('should return a new Map instance (not the internal one)', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest], + }); + + const result1 = engine.getAllPluginManifests(); + const result2 = engine.getAllPluginManifests(); + + // Should be different Map instances + expect(result1).not.toBe(result2); + + // But have the same content + expect(result1.size).toBe(result2.size); + expect(result1.get('lobe-web-browsing')).toBe(result2.get('lobe-web-browsing')); + }); + + it('should reflect changes after adding a plugin', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest], + }); + + let result = engine.getAllPluginManifests(); + expect(result.size).toBe(1); + + engine.addPluginManifest(mockDalleManifest); + + result = engine.getAllPluginManifests(); + expect(result.size).toBe(2); + expect(result.has('dalle')).toBe(true); + }); + + it('should reflect changes after removing a plugin', () => { + const engine = new ToolsEngine({ + manifestSchemas: [mockWebBrowsingManifest, mockDalleManifest], + }); + + let result = engine.getAllPluginManifests(); + expect(result.size).toBe(2); + + engine.removePluginManifest('dalle'); + + result = engine.getAllPluginManifests(); + expect(result.size).toBe(1); + expect(result.has('dalle')).toBe(false); + }); + }); + describe('default behavior', () => { it('should use default enable checker when none provided', () => { const engine = new ToolsEngine({ @@ -373,7 +517,7 @@ describe('ToolsEngine', () => { describe('ToolsEngine Integration Tests (migrated from enabledSchema)', () => { // Mock manifest data similar to the original tool selector tests - const mockManifests: LobeChatPluginManifest[] = [ + const mockManifests: LobeToolManifest[] = [ { identifier: 'plugin-1', api: [{ name: 'api-1', description: 'API 1', parameters: {} }], @@ -730,7 +874,7 @@ describe('ToolsEngine', () => { */ describe('enabledSchema Migration to ToolsEngine', () => { // Sample manifest data that mimics the old toolSelectors test data - const sampleManifests: LobeChatPluginManifest[] = [ + const sampleManifests: LobeToolManifest[] = [ { identifier: 'plugin-1', api: [{ name: 'api-1', description: 'API 1', parameters: {} }], diff --git a/packages/context-engine/src/tools/__tests__/utils.test.ts b/packages/context-engine/src/tools/__tests__/utils.test.ts index 5617ce119d..efd3d9e6e2 100644 --- a/packages/context-engine/src/tools/__tests__/utils.test.ts +++ b/packages/context-engine/src/tools/__tests__/utils.test.ts @@ -1,10 +1,10 @@ import { describe, expect, it } from 'vitest'; -import type { LobeChatPluginManifest } from '../types'; +import type { LobeToolManifest } from '../types'; import { filterValidManifests, validateManifest } from '../utils'; // Mock manifest schemas -const mockBuiltinManifest: LobeChatPluginManifest = { +const mockBuiltinManifest: LobeToolManifest = { api: [ { description: 'Built-in tool', diff --git a/packages/context-engine/src/tools/index.ts b/packages/context-engine/src/tools/index.ts index 8ec9d3ca55..fe3ebc6e38 100644 --- a/packages/context-engine/src/tools/index.ts +++ b/packages/context-engine/src/tools/index.ts @@ -8,6 +8,7 @@ export { ToolNameResolver } from './ToolNameResolver'; export type { FunctionCallChecker, GenerateToolsParams, + LobeToolManifest, PluginEnableChecker, ToolNameGenerator, ToolsEngineOptions, diff --git a/packages/context-engine/src/tools/types.ts b/packages/context-engine/src/tools/types.ts index 165fde7945..abdfd6dc39 100644 --- a/packages/context-engine/src/tools/types.ts +++ b/packages/context-engine/src/tools/types.ts @@ -1,11 +1,26 @@ +import type { HumanInterventionConfig } from '@lobechat/types'; + export interface LobeChatPluginApi { description: string; + /** + * Human intervention configuration + * Controls when and how the tool requires human approval/selection + * + * Can be either: + * - Simple: A policy string ('never', 'always', 'first') + * - Complex: Array of rules for parameter-level control + * + * Examples: + * - 'always' - always require intervention + * - [{ match: { command: "git add:*" }, policy: "never" }, { policy: "always" }] + */ + humanIntervention?: HumanInterventionConfig; name: string; parameters: Record; url?: string; } -export interface LobeChatPluginManifest { +export interface LobeToolManifest { api: LobeChatPluginApi[]; identifier: string; meta: any; @@ -36,7 +51,7 @@ export interface ToolsGenerationContext { */ export type PluginEnableChecker = (params: { context?: ToolsGenerationContext; - manifest: LobeChatPluginManifest; + manifest: LobeToolManifest; model: string; pluginId: string; provider: string; @@ -79,7 +94,7 @@ export interface ToolsEngineOptions { /** Optional tool name generator function */ generateToolName?: ToolNameGenerator; /** Statically injected manifest schemas */ - manifestSchemas: LobeChatPluginManifest[]; + manifestSchemas: LobeToolManifest[]; } /** diff --git a/packages/context-engine/src/tools/utils.ts b/packages/context-engine/src/tools/utils.ts index 3e73edbe0f..defc993e3f 100644 --- a/packages/context-engine/src/tools/utils.ts +++ b/packages/context-engine/src/tools/utils.ts @@ -1,5 +1,5 @@ import { ToolNameResolver } from './ToolNameResolver'; -import { LobeChatPluginManifest } from './types'; +import { LobeToolManifest } from './types'; // Create a singleton instance for backward compatibility const resolver = new ToolNameResolver(); @@ -19,7 +19,7 @@ export const generateToolName = ( /** * Validate manifest schema structure */ -export function validateManifest(manifest: any): manifest is LobeChatPluginManifest { +export function validateManifest(manifest: any): manifest is LobeToolManifest { return Boolean( manifest && typeof manifest === 'object' && @@ -34,9 +34,9 @@ export function validateManifest(manifest: any): manifest is LobeChatPluginManif */ export function filterValidManifests(manifestSchemas: any[]): { invalid: any[]; - valid: LobeChatPluginManifest[]; + valid: LobeToolManifest[]; } { - const valid: LobeChatPluginManifest[] = []; + const valid: LobeToolManifest[] = []; const invalid: any[] = []; for (const manifest of manifestSchemas) { diff --git a/packages/types/src/tool/builtin.ts b/packages/types/src/tool/builtin.ts index 25bab46632..4720dea4d5 100644 --- a/packages/types/src/tool/builtin.ts +++ b/packages/types/src/tool/builtin.ts @@ -1,9 +1,62 @@ -import { LobeChatPluginApi, Meta } from '@lobehub/chat-plugin-sdk'; import { ReactNode } from 'react'; +import type { HumanInterventionConfig, HumanInterventionPolicy } from './intervention'; + +interface Meta { + /** + * avatar + * @desc Avatar of the plugin + * @nameCN 头像 + * @descCN 插件的头像 + */ + avatar?: string; + /** + * description + * @desc Description of the plugin + * @nameCN 描述 + * @descCN 插件的描述 + */ + description?: string; + /** + * tags + * @desc Tags of the plugin + * @nameCN 标签 + * @descCN 插件的标签 + */ + tags?: string[]; + title: string; +} +export interface LobeChatPluginApi { + description: string; + /** + * Human intervention configuration + * Controls when and how the tool requires human approval/selection + * + * Can be either: + * - Simple: A policy string ('never', 'always', 'first') + * - Complex: Array of rules for parameter-level control + * + * Examples: + * - 'always' - always require intervention + * - [{ match: { command: "git add:*" }, policy: "never" }, { policy: "always" }] + */ + humanIntervention?: HumanInterventionConfig; + name: string; + parameters: Record; + url?: string; +} + export interface BuiltinToolManifest { api: LobeChatPluginApi[]; + /** + * Tool-level default human intervention policy + * This policy applies to all APIs that don't specify their own policy + * + * @default 'never' + */ + humanIntervention?: HumanInterventionPolicy; + /** * Plugin name */ diff --git a/packages/types/src/tool/index.ts b/packages/types/src/tool/index.ts index 1385a267cb..a543eeaad8 100644 --- a/packages/types/src/tool/index.ts +++ b/packages/types/src/tool/index.ts @@ -26,6 +26,7 @@ export type LobeToolRenderType = LobePluginType | 'builtin'; export * from './builtin'; export * from './crawler'; export * from './interpreter'; +export * from './intervention'; export * from './plugin'; export * from './search'; export * from './tool'; diff --git a/packages/types/src/tool/intervention.ts b/packages/types/src/tool/intervention.ts new file mode 100644 index 0000000000..af93a8c8e1 --- /dev/null +++ b/packages/types/src/tool/intervention.ts @@ -0,0 +1,114 @@ +/** + * Human Intervention Policy + */ +export type HumanInterventionPolicy = + | 'never' // Never intervene, auto-execute + | 'always' // Always require intervention + | 'first'; // Require intervention on first call only + +/** + * Argument Matcher for parameter-level filtering + * Supports wildcard patterns, prefix matching, and regex + * + * Examples: + * - "git add:*" - matches any git add command + * - "/Users/project/*" - matches paths under /Users/project/ + * - { pattern: "^rm.*", type: "regex" } - regex matching + */ +export type ArgumentMatcher = + | string // Simple string or wildcard pattern + | { + pattern: string; + type: 'exact' | 'prefix' | 'wildcard' | 'regex'; + }; + +/** + * Human Intervention Rule + * Used for parameter-level control of intervention behavior + */ +export interface HumanInterventionRule { + /** + * Parameter filter - matches against tool call arguments + * Key is the parameter name, value is the matcher + * + * Example: + * { command: "git add:*" } - matches when command param starts with "git add" + */ + match?: Record; + + /** + * Policy to apply when this rule matches + */ + policy: HumanInterventionPolicy; +} + +/** + * Human Intervention Configuration + * Can be either: + * - Simple: Direct policy string for uniform behavior + * - Complex: Array of rules for parameter-level control + * + * Examples: + * - "always" - always require intervention + * - [{ match: { command: "ls:*" }, policy: "never" }, { policy: "always" }] + */ +export type HumanInterventionConfig = HumanInterventionPolicy | HumanInterventionRule[]; + +/** + * Human Intervention Response + * User's response to an intervention request + */ +export interface HumanInterventionResponse { + /** + * User's action: + * - approve: Allow the tool call to proceed + * - reject: Deny the tool call + * - select: User made a selection from options + */ + action: 'approve' | 'reject' | 'select'; + + /** + * Additional data based on action type + */ + data?: { + /** + * Whether to remember this decision for future calls + * Only applicable for 'first' policy + */ + remember?: boolean; + + /** + * Selected value(s) for select action + * Can be single string or array for multi-select + */ + selected?: string | string[]; + }; +} + +/** + * Parameters for shouldIntervene method + */ +export interface ShouldInterveneParams { + /** + * Intervention configuration (from manifest or user override) + */ + config: HumanInterventionConfig | undefined; + + /** + * List of confirmed tool calls (for 'first' policy) + * @default [] + */ + confirmedHistory?: string[]; + + /** + * Tool call arguments to check against rules + * @default {} + */ + toolArgs?: Record; + + /** + * Tool key to check against confirmed history + * Format: "identifier/apiName" or "identifier/apiName#argsHash" + */ + toolKey?: string; +} diff --git a/packages/types/src/user/settings/tool.ts b/packages/types/src/user/settings/tool.ts index ec6ab9818c..c0090e0ac2 100644 --- a/packages/types/src/user/settings/tool.ts +++ b/packages/types/src/user/settings/tool.ts @@ -1,5 +1,42 @@ +import type { HumanInterventionConfig } from '../../tool'; + export interface UserToolConfig { dalle: { autoGenerate: boolean; }; + /** + * Human intervention configuration + */ + humanIntervention?: { + /** + * List of confirmed tool calls (for 'once' policy) + * Format: "identifier/apiName" or "identifier/apiName#argsHash" + * + * Examples: + * - "web-browsing/crawlSinglePage" + * - "bash/bash#a1b2c3d4" + */ + confirmed?: string[]; + + /** + * Whether human intervention is enabled globally + * @default true + */ + enabled: boolean; + + /** + * Per-tool intervention policy overrides + * Key format: "identifier/apiName" + * + * Example: + * { + * "web-browsing/crawlSinglePage": "confirm", + * "bash/bash": [ + * { match: { command: "git add:*" }, policy: "auto" }, + * { policy: "confirm" } + * ] + * } + */ + overrides?: Record; + }; } diff --git a/src/features/Conversation/Messages/Assistant/Tool/Render/index.tsx b/src/features/Conversation/Messages/Assistant/Tool/Render/index.tsx index 6b83c97852..945b240caf 100644 --- a/src/features/Conversation/Messages/Assistant/Tool/Render/index.tsx +++ b/src/features/Conversation/Messages/Assistant/Tool/Render/index.tsx @@ -52,8 +52,10 @@ const Render = memo( // 如果是 LOADING_FLAT 则说明还在加载中 // 而 standalone 模式的插件 content 应该始终是 LOADING_FLAT - if (toolMessage.content === LOADING_FLAT && toolMessage.plugin?.type !== 'standalone') - return placeholder; + const inPlaceholder = + toolMessage.content === LOADING_FLAT && toolMessage.plugin?.type !== 'standalone'; + + if (inPlaceholder) return placeholder; return ( diff --git a/src/store/chat/slices/builtinTool/actions/dalle.test.ts b/src/store/chat/slices/builtinTool/actions/__tests__/dalle.test.ts similarity index 95% rename from src/store/chat/slices/builtinTool/actions/dalle.test.ts rename to src/store/chat/slices/builtinTool/actions/__tests__/dalle.test.ts index 127009aa05..6adabe249a 100644 --- a/src/store/chat/slices/builtinTool/actions/dalle.test.ts +++ b/src/store/chat/slices/builtinTool/actions/__tests__/dalle.test.ts @@ -1,18 +1,15 @@ +import { ChatMessage } from '@lobechat/types'; import { act, renderHook } from '@testing-library/react'; import { describe, expect, it, vi } from 'vitest'; -import { fileService } from '@/services/file'; -import { ClientService } from '@/services/file/_deprecated'; import { messageService } from '@/services/message'; import { imageGenerationService } from '@/services/textToImage'; import { uploadService } from '@/services/upload'; +import { useChatStore } from '@/store/chat'; import { chatSelectors } from '@/store/chat/selectors'; import { useFileStore } from '@/store/file'; -import { ChatMessage } from '@/types/message'; import { DallEImageItem } from '@/types/tool/dalle'; -import { useChatStore } from '../../../store'; - describe('chatToolSlice - dalle', () => { describe('generateImageFromPrompts', () => { it('should generate images from prompts, update items, and upload images', async () => { diff --git a/src/store/chat/slices/builtinTool/actions/__tests__/localFile.test.ts b/src/store/chat/slices/builtinTool/actions/__tests__/localSystem.test.ts similarity index 96% rename from src/store/chat/slices/builtinTool/actions/__tests__/localFile.test.ts rename to src/store/chat/slices/builtinTool/actions/__tests__/localSystem.test.ts index 7abe30d81e..ff28ff1fe5 100644 --- a/src/store/chat/slices/builtinTool/actions/__tests__/localFile.test.ts +++ b/src/store/chat/slices/builtinTool/actions/__tests__/localSystem.test.ts @@ -2,9 +2,9 @@ import { LocalFileItem, LocalMoveFilesResultItem } from '@lobechat/electron-clie import { describe, expect, it, vi } from 'vitest'; import { localFileService } from '@/services/electron/localFileService'; -import { ChatStore } from '@/store/chat/store'; +import { ChatStore } from '@/store/chat'; -import { localFileSlice } from '../localFile'; +import { localSystemSlice } from '../localSystem'; vi.mock('@/services/electron/localFileService', () => ({ localFileService: { @@ -31,7 +31,7 @@ const mockStore = { } as unknown as ChatStore; const createStore = () => { - return localFileSlice( + return localSystemSlice( (set) => ({ ...mockStore, set, @@ -191,7 +191,7 @@ describe('localFileSlice', () => { describe('toggleLocalFileLoading', () => { it('should toggle loading state', () => { const mockSetFn = vi.fn(); - const testStore = localFileSlice(mockSetFn, () => mockStore, {} as any); + const testStore = localSystemSlice(mockSetFn, () => mockStore, {} as any); testStore.toggleLocalFileLoading('test-id', true); expect(mockSetFn).toHaveBeenCalledWith( diff --git a/src/store/chat/slices/builtinTool/actions/search.test.ts b/src/store/chat/slices/builtinTool/actions/__tests__/search.test.ts similarity index 100% rename from src/store/chat/slices/builtinTool/actions/search.test.ts rename to src/store/chat/slices/builtinTool/actions/__tests__/search.test.ts diff --git a/src/store/chat/slices/builtinTool/actions/index.ts b/src/store/chat/slices/builtinTool/actions/index.ts index 518e2ae423..512a0abb26 100644 --- a/src/store/chat/slices/builtinTool/actions/index.ts +++ b/src/store/chat/slices/builtinTool/actions/index.ts @@ -4,7 +4,7 @@ import { ChatStore } from '@/store/chat/store'; import { ChatDallEAction, dalleSlice } from './dalle'; import { ChatCodeInterpreterAction, codeInterpreterSlice } from './interpreter'; -import { LocalFileAction, localFileSlice } from './localFile'; +import { LocalFileAction, localSystemSlice } from './localSystem'; import { SearchAction, searchSlice } from './search'; export interface ChatBuiltinToolAction @@ -21,6 +21,6 @@ export const chatToolSlice: StateCreator< > = (...params) => ({ ...dalleSlice(...params), ...searchSlice(...params), - ...localFileSlice(...params), + ...localSystemSlice(...params), ...codeInterpreterSlice(...params), }); diff --git a/src/store/chat/slices/builtinTool/actions/localFile.ts b/src/store/chat/slices/builtinTool/actions/localSystem.ts similarity index 63% rename from src/store/chat/slices/builtinTool/actions/localFile.ts rename to src/store/chat/slices/builtinTool/actions/localSystem.ts index c779d40fde..2820617873 100644 --- a/src/store/chat/slices/builtinTool/actions/localFile.ts +++ b/src/store/chat/slices/builtinTool/actions/localSystem.ts @@ -1,4 +1,9 @@ import { + EditLocalFileParams, + GetCommandOutputParams, + GlobFilesParams, + GrepContentParams, + KillCommandParams, ListLocalFileParams, LocalMoveFilesResultItem, LocalReadFileParams, @@ -6,6 +11,7 @@ import { LocalSearchFilesParams, MoveLocalFilesParams, RenameLocalFileParams, + RunCommandParams, WriteLocalFileParams, } from '@lobechat/electron-client-ipc'; import { StateCreator } from 'zustand/vanilla'; @@ -13,67 +19,91 @@ import { StateCreator } from 'zustand/vanilla'; import { localFileService } from '@/services/electron/localFileService'; import { ChatStore } from '@/store/chat/store'; import { + EditLocalFileState, + GetCommandOutputState, + GlobFilesState, + GrepContentState, + KillCommandState, LocalFileListState, LocalFileSearchState, LocalMoveFilesState, LocalReadFileState, LocalReadFilesState, LocalRenameFileState, + RunCommandState, } from '@/tools/local-system/type'; +/* eslint-disable typescript-sort-keys/interface */ export interface LocalFileAction { internal_triggerLocalFileToolCalling: ( id: string, callingService: () => Promise<{ content: any; state?: T }>, ) => Promise; + // File Operations listLocalFiles: (id: string, params: ListLocalFileParams) => Promise; moveLocalFiles: (id: string, params: MoveLocalFilesParams) => Promise; - reSearchLocalFiles: (id: string, params: LocalSearchFilesParams) => Promise; readLocalFile: (id: string, params: LocalReadFileParams) => Promise; readLocalFiles: (id: string, params: LocalReadFilesParams) => Promise; renameLocalFile: (id: string, params: RenameLocalFileParams) => Promise; - // Added rename action + reSearchLocalFiles: (id: string, params: LocalSearchFilesParams) => Promise; searchLocalFiles: (id: string, params: LocalSearchFilesParams) => Promise; toggleLocalFileLoading: (id: string, loading: boolean) => void; - writeLocalFile: (id: string, params: WriteLocalFileParams) => Promise; -} -export const localFileSlice: StateCreator< + // Shell Commands + editLocalFile: (id: string, params: EditLocalFileParams) => Promise; + getCommandOutput: (id: string, params: GetCommandOutputParams) => Promise; + killCommand: (id: string, params: KillCommandParams) => Promise; + runCommand: (id: string, params: RunCommandParams) => Promise; + + // Search & Find + globLocalFiles: (id: string, params: GlobFilesParams) => Promise; + grepContent: (id: string, params: GrepContentParams) => Promise; +} +/* eslint-enable typescript-sort-keys/interface */ + +/* eslint-disable sort-keys-fix/sort-keys-fix */ +export const localSystemSlice: StateCreator< ChatStore, [['zustand/devtools', never]], [], LocalFileAction > = (set, get) => ({ - internal_triggerLocalFileToolCalling: async (id, callingService) => { - get().toggleLocalFileLoading(id, true); - try { - const { state, content } = await callingService(); - if (state) { - await get().updatePluginState(id, state as any); - } - await get().internal_updateMessageContent(id, JSON.stringify(content)); - } catch (error) { - await get().internal_updateMessagePluginError(id, { - body: error, - message: (error as Error).message, - type: 'PluginServerError', - }); - } - get().toggleLocalFileLoading(id, false); + // ==================== File Editing ==================== + editLocalFile: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.editLocalFile(params); - return true; - }, + const message = result.success + ? `Successfully replaced ${result.replacements} occurrence(s) in ${params.file_path}` + : `Edit failed: ${result.error}`; + + const state: EditLocalFileState = { message, result }; - listLocalFiles: async (id, params) => { - return get().internal_triggerLocalFileToolCalling(id, async () => { - const result = await localFileService.listLocalFiles(params); - const state: LocalFileListState = { listResults: result }; return { content: result, state }; }); }, + writeLocalFile: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.writeFile(params); + + let content: { message: string; success: boolean }; + + if (result.success) { + content = { + message: `成功写入文件 ${params.path}`, + success: true, + }; + } else { + const errorMessage = result.error; + + content = { message: errorMessage || '写入文件失败', success: false }; + } + return { content }; + }); + }, moveLocalFiles: async (id, params) => { return get().internal_triggerLocalFileToolCalling(id, async () => { const results: LocalMoveFilesResultItem[] = await localFileService.moveLocalFiles(params); @@ -100,31 +130,6 @@ export const localFileSlice: StateCreator< return { content: { message, results }, state }; }); }, - - reSearchLocalFiles: async (id, params) => { - get().toggleLocalFileLoading(id, true); - - await get().updatePluginArguments(id, params); - - return get().searchLocalFiles(id, params); - }, - - readLocalFile: async (id, params) => { - return get().internal_triggerLocalFileToolCalling(id, async () => { - const result = await localFileService.readLocalFile(params); - const state: LocalReadFileState = { fileContent: result }; - return { content: result, state }; - }); - }, - - readLocalFiles: async (id, params) => { - return get().internal_triggerLocalFileToolCalling(id, async () => { - const results = await localFileService.readLocalFiles(params); - const state: LocalReadFilesState = { filesContent: results }; - return { content: results, state }; - }); - }, - renameLocalFile: async (id, params) => { return get().internal_triggerLocalFileToolCalling(id, async () => { const { path: currentPath, newName } = params; @@ -169,6 +174,33 @@ export const localFileSlice: StateCreator< }); }, + // ==================== Search & Find ==================== + grepContent: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.grepContent(params); + + const message = result.success + ? `Found ${result.total_matches} matches in ${result.matches.length} locations` + : 'Search failed'; + + const state: GrepContentState = { message, result }; + + return { content: result, state }; + }); + }, + + globLocalFiles: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.globFiles(params); + + const message = result.success ? `Found ${result.total_files} files` : 'Glob search failed'; + + const state: GlobFilesState = { message, result }; + + return { content: result, state }; + }); + }, + searchLocalFiles: async (id, params) => { return get().internal_triggerLocalFileToolCalling(id, async () => { const result = await localFileService.searchLocalFiles(params); @@ -177,6 +209,89 @@ export const localFileSlice: StateCreator< }); }, + listLocalFiles: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.listLocalFiles(params); + const state: LocalFileListState = { listResults: result }; + return { content: result, state }; + }); + }, + + reSearchLocalFiles: async (id, params) => { + get().toggleLocalFileLoading(id, true); + + await get().updatePluginArguments(id, params); + + return get().searchLocalFiles(id, params); + }, + + readLocalFile: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.readLocalFile(params); + const state: LocalReadFileState = { fileContent: result }; + return { content: result, state }; + }); + }, + + readLocalFiles: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const results = await localFileService.readLocalFiles(params); + const state: LocalReadFilesState = { filesContent: results }; + return { content: results, state }; + }); + }, + + // ==================== Shell Commands ==================== + runCommand: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.runCommand(params); + + let message: string; + + if (result.success) { + if (result.shell_id) { + message = `Command started in background with shell_id: ${result.shell_id}`; + } else { + message = `Command completed successfully. Exit code: ${result.exit_code}`; + } + } else { + message = `Command failed: ${result.error}`; + } + + const state: RunCommandState = { message, result }; + + return { content: result, state }; + }); + }, + killCommand: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.killCommand(params); + + const message = result.success + ? `Successfully killed shell: ${params.shell_id}` + : `Failed to kill shell: ${result.error}`; + + const state: KillCommandState = { message, result }; + + return { content: result, state }; + }); + }, + getCommandOutput: async (id, params) => { + return get().internal_triggerLocalFileToolCalling(id, async () => { + const result = await localFileService.getCommandOutput(params); + + const message = result.success + ? `Output retrieved. Running: ${result.running}` + : `Failed: ${result.error}`; + + const state: GetCommandOutputState = { message, result }; + + return { content: result, state }; + }); + }, + + // ==================== utils ==================== + toggleLocalFileLoading: (id, loading) => { // Assuming a loading state structure similar to searchLoading set( @@ -187,24 +302,23 @@ export const localFileSlice: StateCreator< `toggleLocalFileLoading/${loading ? 'start' : 'end'}`, ); }, - - writeLocalFile: async (id, params) => { - return get().internal_triggerLocalFileToolCalling(id, async () => { - const result = await localFileService.writeFile(params); - - let content: { message: string; success: boolean }; - - if (result.success) { - content = { - message: `成功写入文件 ${params.path}`, - success: true, - }; - } else { - const errorMessage = result.error; - - content = { message: errorMessage || '写入文件失败', success: false }; + internal_triggerLocalFileToolCalling: async (id, callingService) => { + get().toggleLocalFileLoading(id, true); + try { + const { state, content } = await callingService(); + if (state) { + await get().updatePluginState(id, state as any); } - return { content }; - }); + await get().internal_updateMessageContent(id, JSON.stringify(content)); + } catch (error) { + await get().internal_updateMessagePluginError(id, { + body: error, + message: (error as Error).message, + type: 'PluginServerError', + }); + } + get().toggleLocalFileLoading(id, false); + + return true; }, }); diff --git a/src/store/electron/selectors/__tests__/desktopState.test.ts b/src/store/electron/selectors/__tests__/desktopState.test.ts index d7ce38f34e..acf3c08928 100644 --- a/src/store/electron/selectors/__tests__/desktopState.test.ts +++ b/src/store/electron/selectors/__tests__/desktopState.test.ts @@ -6,7 +6,7 @@ import { merge } from '@/utils/merge'; import { desktopStateSelectors } from '../desktopState'; describe('desktopStateSelectors', () => { - describe('usePath', () => { + describe('userPath', () => { it('should return userPath from appState', () => { const state: ElectronState = merge(initialState, { appState: { @@ -23,7 +23,7 @@ describe('desktopStateSelectors', () => { }, }); - expect(desktopStateSelectors.usePath(state)).toEqual({ + expect(desktopStateSelectors.userPath(state)).toEqual({ desktop: '/test/desktop', documents: '/test/documents', downloads: '/test/downloads', @@ -40,7 +40,7 @@ describe('desktopStateSelectors', () => { appState: {}, }); - expect(desktopStateSelectors.usePath(state)).toBeUndefined(); + expect(desktopStateSelectors.userPath(state)).toBeUndefined(); }); }); }); diff --git a/src/store/electron/selectors/desktopState.ts b/src/store/electron/selectors/desktopState.ts index a85239dd35..8e48a6a3fd 100644 --- a/src/store/electron/selectors/desktopState.ts +++ b/src/store/electron/selectors/desktopState.ts @@ -1,7 +1,16 @@ import { ElectronState } from '@/store/electron/initialState'; -const usePath = (s: ElectronState) => s.appState.userPath; +const userPath = (s: ElectronState) => s.appState.userPath; +const userHomePath = (s: ElectronState) => userPath(s)?.home || ''; + +const displayRelativePath = (path: string) => (s: ElectronState) => { + const basePath = userHomePath(s); + + return !!basePath ? path.replaceAll(basePath, '~') : path; +}; export const desktopStateSelectors = { - usePath, + displayRelativePath, + userHomePath, + userPath, }; diff --git a/src/tools/local-system/Placeholder/ListFiles.tsx b/src/tools/local-system/Placeholder/ListFiles.tsx index c276463f6f..d9c492e9e5 100644 --- a/src/tools/local-system/Placeholder/ListFiles.tsx +++ b/src/tools/local-system/Placeholder/ListFiles.tsx @@ -1,7 +1,7 @@ import { ListLocalFileParams } from '@lobechat/electron-client-ipc'; import { Skeleton } from 'antd'; import React, { memo } from 'react'; -import { Flexbox } from 'react-layout-kit'; +import { Center, Flexbox } from 'react-layout-kit'; import { LocalFolder } from '@/features/LocalFile'; @@ -10,14 +10,16 @@ interface ListFilesProps { } export const ListFiles = memo(({ args }) => { return ( - + - - - - - - +
+ + + + + + +
); }); diff --git a/src/tools/local-system/Placeholder/ReadLocalFile.tsx b/src/tools/local-system/Placeholder/ReadLocalFile.tsx deleted file mode 100644 index 85b0d1e4a8..0000000000 --- a/src/tools/local-system/Placeholder/ReadLocalFile.tsx +++ /dev/null @@ -1,9 +0,0 @@ -'use client'; - -import { memo } from 'react'; - -import Skeleton from '../Render/ReadLocalFile/ReadFileSkeleton'; - -const ReadLocalFile = memo(() => ); - -export default ReadLocalFile; diff --git a/src/tools/local-system/Placeholder/SearchFiles.tsx b/src/tools/local-system/Placeholder/SearchFiles.tsx index 7c853dea81..863a2eb606 100644 --- a/src/tools/local-system/Placeholder/SearchFiles.tsx +++ b/src/tools/local-system/Placeholder/SearchFiles.tsx @@ -3,8 +3,8 @@ import { Icon } from '@lobehub/ui'; import { Skeleton } from 'antd'; import { createStyles } from 'antd-style'; import { SearchIcon } from 'lucide-react'; -import { memo } from 'react'; -import { Flexbox } from 'react-layout-kit'; +import React, { memo } from 'react'; +import { Center, Flexbox } from 'react-layout-kit'; const useStyles = createStyles(({ css, token, cx }) => ({ query: cx(css` @@ -29,8 +29,8 @@ const SearchFiles = memo(({ args }) => { const { styles } = useStyles(); return ( - - + + {args.keywords ? ( @@ -42,12 +42,14 @@ const SearchFiles = memo(({ args }) => { - - - - - - +
+ + + + + + +
); }); diff --git a/src/tools/local-system/Placeholder/index.tsx b/src/tools/local-system/Placeholder/index.tsx index 377441a4aa..9d805c3ea5 100644 --- a/src/tools/local-system/Placeholder/index.tsx +++ b/src/tools/local-system/Placeholder/index.tsx @@ -3,8 +3,8 @@ import { memo } from 'react'; import { LocalSystemApiName } from '@/tools/local-system'; +import ReadLocalFile from '../Render/ReadLocalFile/ReadFileSkeleton'; import { ListFiles } from './ListFiles'; -import ReadLocalFile from './ReadLocalFile'; import SearchFiles from './SearchFiles'; const RenderMap = { diff --git a/src/tools/local-system/Render/ReadLocalFile/ReadFileSkeleton.tsx b/src/tools/local-system/Render/ReadLocalFile/ReadFileSkeleton.tsx index b7b07610f2..2346296824 100644 --- a/src/tools/local-system/Render/ReadLocalFile/ReadFileSkeleton.tsx +++ b/src/tools/local-system/Render/ReadLocalFile/ReadFileSkeleton.tsx @@ -9,40 +9,30 @@ const useStyles = createStyles(({ css, token }) => ({ border: 1px solid ${token.colorBorderSecondary}; border-radius: ${token.borderRadiusLG}px; `, - header: css` - margin-block-end: 4px; - `, + meta: css` font-size: 12px; `, - path: css` - margin-block-start: 4px; - `, })); const ReadFileSkeleton = memo(() => { const { styles } = useStyles(); return ( - - + + - - + + + - + {/* Path */} - + ); }); diff --git a/src/tools/local-system/Render/ReadLocalFile/ReadFileView.tsx b/src/tools/local-system/Render/ReadLocalFile/ReadFileView.tsx index 42ccfbb8b1..efbd8c65f1 100644 --- a/src/tools/local-system/Render/ReadLocalFile/ReadFileView.tsx +++ b/src/tools/local-system/Render/ReadLocalFile/ReadFileView.tsx @@ -8,6 +8,8 @@ import { Flexbox } from 'react-layout-kit'; import FileIcon from '@/components/FileIcon'; import { localFileService } from '@/services/electron/localFileService'; +import { useElectronStore } from '@/store/electron'; +import { desktopStateSelectors } from '@/store/electron/selectors'; const useStyles = createStyles(({ css, token, cx }) => ({ actions: cx( @@ -20,11 +22,19 @@ const useStyles = createStyles(({ css, token, cx }) => ({ `, ), container: css` + justify-content: space-between; + + height: 64px; padding: 8px; border: 1px solid ${token.colorBorderSecondary}; border-radius: ${token.borderRadiusLG}px; + transition: all 0.2s ${token.motionEaseInOut}; + .local-file-actions { + opacity: 0; + } + &:hover { border-color: ${token.colorBorder}; @@ -48,10 +58,13 @@ const useStyles = createStyles(({ css, token, cx }) => ({ lineCount: css` color: ${token.colorTextQuaternary}; `, - meta: css` - font-size: 12px; - color: ${token.colorTextTertiary}; - `, + meta: cx( + 'local-file-actions', + css` + font-size: 12px; + color: ${token.colorTextTertiary}; + `, + ), path: css` margin-block-start: 4px; padding-inline: 4px; @@ -104,6 +117,8 @@ const ReadFileView = memo( localFileService.openLocalFolder({ isDirectory: false, path }); }; + const displayPath = useElectronStore(desktopStateSelectors.displayRelativePath(path)); + return ( ( onClick={handleToggleExpand} > - + {filename} @@ -174,7 +189,7 @@ const ReadFileView = memo( {/* Path */} - {path} + {displayPath} {isExpanded && ( diff --git a/src/tools/local-system/Render/ReadLocalFile/style.ts b/src/tools/local-system/Render/ReadLocalFile/style.ts deleted file mode 100644 index 502d3a9d43..0000000000 --- a/src/tools/local-system/Render/ReadLocalFile/style.ts +++ /dev/null @@ -1,37 +0,0 @@ -import { createStyles } from 'antd-style'; - -export const useStyles = createStyles(({ css, token }) => ({ - container: css` - overflow: hidden; - - max-width: 100%; - padding: 12px; - border: 1px solid ${token.colorBorderSecondary}; - border-radius: ${token.borderRadius}px; - `, - fileName: css` - color: ${token.colorTextSecondary}; - `, - meta: css` - font-size: 10px; - color: ${token.colorTextSecondary}; - `, - metaItem: css` - white-space: nowrap; - `, - path: css` - font-size: 12px; - line-height: 1; - `, - previewBox: css` - padding-block: 8px; - padding-inline: 12px; - border-radius: ${token.borderRadiusSM}px; - background: ${token.colorFillTertiary}; - `, - previewText: css` - font-family: ${token.fontFamilyCode}; - font-size: 12px; - color: ${token.colorTextSecondary}; - `, -})); diff --git a/src/tools/local-system/Render/SearchFiles/Result.tsx b/src/tools/local-system/Render/SearchFiles/Result.tsx index 8846566a82..c88da7ffcd 100644 --- a/src/tools/local-system/Render/SearchFiles/Result.tsx +++ b/src/tools/local-system/Render/SearchFiles/Result.tsx @@ -1,12 +1,13 @@ +import { ChatMessagePluginError } from '@lobechat/types'; import { Skeleton } from 'antd'; import { memo } from 'react'; import { Flexbox } from 'react-layout-kit'; import { useChatStore } from '@/store/chat'; import { chatToolSelectors } from '@/store/chat/selectors'; -import FileItem from '@/tools/local-system/components/FileItem'; -import { FileResult } from '@/tools/local-system/type'; -import { ChatMessagePluginError } from '@/types/message'; + +import FileItem from '../../components/FileItem'; +import { FileResult } from '../../type'; interface SearchFilesProps { messageId: string; @@ -29,7 +30,7 @@ const SearchFiles = memo(({ searchResults = [], messageId }) = } return ( - + {searchResults.map((item) => ( ))} diff --git a/src/tools/local-system/Render/SearchFiles/SearchQuery/SearchView.tsx b/src/tools/local-system/Render/SearchFiles/SearchQuery/SearchView.tsx index 4465701f98..fc1fdb1619 100644 --- a/src/tools/local-system/Render/SearchFiles/SearchQuery/SearchView.tsx +++ b/src/tools/local-system/Render/SearchFiles/SearchQuery/SearchView.tsx @@ -1,12 +1,10 @@ import { Icon } from '@lobehub/ui'; -import { Skeleton } from 'antd'; import { createStyles } from 'antd-style'; import { SearchIcon } from 'lucide-react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; -import { useIsMobile } from '@/hooks/useIsMobile'; import { shinyTextStylish } from '@/styles/loading'; const useStyles = createStyles(({ css, token }) => ({ @@ -41,16 +39,9 @@ interface SearchBarProps { const SearchBar = memo( ({ defaultQuery, resultsNumber, onEditingChange, searching }) => { const { t } = useTranslation('tool'); - const isMobile = useIsMobile(); const { styles, cx } = useStyles(); return ( - + ( {defaultQuery} - - <> -
{t('search.searchResult')}
- {searching ? : resultsNumber} - + +
{t('search.searchResult')}
+ {resultsNumber}
); diff --git a/src/tools/local-system/Render/SearchFiles/index.tsx b/src/tools/local-system/Render/SearchFiles/index.tsx index 126bf8a514..055fe33d4a 100644 --- a/src/tools/local-system/Render/SearchFiles/index.tsx +++ b/src/tools/local-system/Render/SearchFiles/index.tsx @@ -1,5 +1,6 @@ import { LocalSearchFilesParams } from '@lobechat/electron-client-ipc'; import { memo } from 'react'; +import { Flexbox } from 'react-layout-kit'; import { LocalFileSearchState } from '@/tools/local-system/type'; import { ChatMessagePluginError } from '@/types/message'; @@ -16,14 +17,14 @@ interface SearchFilesProps { const SearchFiles = memo(({ messageId, pluginError, args, pluginState }) => { return ( - <> + - + ); }); diff --git a/src/tools/local-system/type.ts b/src/tools/local-system/type.ts index 04b17f7e5c..1e6cc827f3 100644 --- a/src/tools/local-system/type.ts +++ b/src/tools/local-system/type.ts @@ -1,7 +1,13 @@ import { + EditLocalFileResult, + GetCommandOutputResult, + GlobFilesResult, + GrepContentResult, + KillCommandResult, LocalFileItem, LocalMoveFilesResultItem, LocalReadFileResult, + RunCommandResult, } from '@lobechat/electron-client-ipc'; export interface FileResult { @@ -49,3 +55,36 @@ export interface LocalRenameFileState { oldPath: string; success: boolean; } + +// Shell Command States +export interface RunCommandState { + message: string; + result: RunCommandResult; +} + +export interface GetCommandOutputState { + message: string; + result: GetCommandOutputResult; +} + +export interface KillCommandState { + message: string; + result: KillCommandResult; +} + +// Search & Find States +export interface GrepContentState { + message: string; + result: GrepContentResult; +} + +export interface GlobFilesState { + message: string; + result: GlobFilesResult; +} + +// Edit State +export interface EditLocalFileState { + message: string; + result: EditLocalFileResult; +}