From 519e03e87a8e0687f64a1db3eb260ffff03658f1 Mon Sep 17 00:00:00 2001 From: YuTengjing Date: Sun, 17 Aug 2025 13:50:05 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(provider):=20add=20BFL=20provi?= =?UTF-8?q?der=20support=20for=20image=20generation=20(#8806)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 2 +- package.json | 2 +- packages/const/src/image.ts | 10 +- .../model-runtime/src/bfl/createImage.test.ts | 846 ++++++++++++++++++ packages/model-runtime/src/bfl/createImage.ts | 279 ++++++ packages/model-runtime/src/bfl/index.test.ts | 269 ++++++ packages/model-runtime/src/bfl/index.ts | 49 + packages/model-runtime/src/bfl/types.ts | 113 +++ packages/model-runtime/src/index.ts | 1 + .../model-runtime/src/qwen/createImage.ts | 115 +-- packages/model-runtime/src/runtimeMap.ts | 2 + .../src/utils/asyncifyPolling.test.ts | 491 ++++++++++ .../src/utils/asyncifyPolling.ts | 175 ++++ .../@menu/features/ConfigPanel/index.tsx | 2 +- src/config/aiModels/bfl.ts | 145 +++ src/config/aiModels/index.ts | 3 + src/config/llm.ts | 7 +- src/config/modelProviders/bfl.ts | 21 + src/config/modelProviders/index.ts | 3 + .../image/slices/generationConfig/hooks.ts | 16 +- 20 files changed, 2456 insertions(+), 95 deletions(-) create mode 100644 packages/model-runtime/src/bfl/createImage.test.ts create mode 100644 packages/model-runtime/src/bfl/createImage.ts create mode 100644 packages/model-runtime/src/bfl/index.test.ts create mode 100644 packages/model-runtime/src/bfl/index.ts create mode 100644 packages/model-runtime/src/bfl/types.ts create mode 100644 packages/model-runtime/src/utils/asyncifyPolling.test.ts create mode 100644 packages/model-runtime/src/utils/asyncifyPolling.ts create mode 100644 src/config/aiModels/bfl.ts create mode 100644 src/config/modelProviders/bfl.ts diff --git a/.vscode/settings.json b/.vscode/settings.json index fba169ddde..83366d34fa 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -68,7 +68,7 @@ "**/src/config/modelProviders/*.ts": "${filename} • provider", "**/src/config/aiModels/*.ts": "${filename} • model", "**/src/config/paramsSchemas/*/*.json": "${dirname(1)}/${filename} • params", - "**/src/libs/model-runtime/*/index.ts": "${dirname} • runtime", + "**/packages/model-runtime/src/*/index.ts": "${dirname} • runtime", "**/src/server/services/*/index.ts": "${dirname} • server/service", "**/src/server/routers/lambda/*.ts": "${filename} • lambda", diff --git a/package.json b/package.json index 5af17fd61e..3c9f7f34e9 100644 --- a/package.json +++ b/package.json @@ -151,7 +151,7 @@ "@lobehub/charts": "^2.0.0", "@lobehub/chat-plugin-sdk": "^1.32.4", "@lobehub/chat-plugins-gateway": "^1.9.0", - "@lobehub/icons": "^2.25.0", + "@lobehub/icons": "^2.27.1", "@lobehub/market-sdk": "^0.22.7", "@lobehub/tts": "^2.0.1", "@lobehub/ui": "^2.8.3", diff --git a/packages/const/src/image.ts b/packages/const/src/image.ts index 48c097c8a4..d3057ac2cc 100644 --- a/packages/const/src/image.ts +++ b/packages/const/src/image.ts @@ -3,4 +3,12 @@ */ export const DEFAULT_ASPECT_RATIO = '1:1'; -export const PRESET_ASPECT_RATIOS = [DEFAULT_ASPECT_RATIO, '16:9', '9:16', '4:3', '3:4']; +export const PRESET_ASPECT_RATIOS = [ + DEFAULT_ASPECT_RATIO, // '1:1' - 正方形,最常用 + '16:9', // 现代显示器/电视/视频标准 + '9:16', // 手机竖屏/短视频 + '4:3', // 传统显示器/照片 + '3:4', // 传统竖屏照片 + '3:2', // 经典照片比例横屏 + '2:3', // 经典照片比例竖屏 +]; diff --git a/packages/model-runtime/src/bfl/createImage.test.ts b/packages/model-runtime/src/bfl/createImage.test.ts new file mode 100644 index 0000000000..a682efe7ec --- /dev/null +++ b/packages/model-runtime/src/bfl/createImage.test.ts @@ -0,0 +1,846 @@ +// @vitest-environment node +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { CreateImagePayload } from '@/libs/model-runtime/types/image'; + +import { createBflImage } from './createImage'; +import { BflStatusResponse } from './types'; + +// Mock external dependencies +vi.mock('@/utils/imageToBase64', () => ({ + imageUrlToBase64: vi.fn(), +})); + +vi.mock('../utils/uriParser', () => ({ + parseDataUri: vi.fn(), +})); + +vi.mock('../utils/asyncifyPolling', () => ({ + asyncifyPolling: vi.fn(), +})); + +// Mock fetch +global.fetch = vi.fn(); +const mockFetch = vi.mocked(fetch); + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +const mockOptions = { + apiKey: 'test-api-key', + provider: 'bfl' as const, +}; + +beforeEach(() => { + vi.clearAllMocks(); +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('createBflImage', () => { + describe('Parameter mapping and defaults', () => { + it('should map standard parameters to BFL-specific parameters', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'A beautiful landscape', + aspectRatio: '16:9', + cfg: 7.5, + steps: 20, + seed: 12345, + }, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.bfl.ai/v1/flux-dev', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-key': 'test-api-key', + }, + body: JSON.stringify({ + output_format: 'png', + safety_tolerance: 6, + prompt: 'A beautiful landscape', + aspect_ratio: '16:9', + guidance: 7.5, + steps: 20, + seed: 12345, + }), + }), + ); + }); + + it('should add raw: true for ultra models', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const payload: CreateImagePayload = { + model: 'flux-pro-1.1-ultra', + params: { + prompt: 'Ultra quality image', + }, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.bfl.ai/v1/flux-pro-1.1-ultra', + expect.objectContaining({ + body: JSON.stringify({ + output_format: 'png', + safety_tolerance: 6, + raw: true, + prompt: 'Ultra quality image', + }), + }), + ); + }); + + it('should filter out undefined values', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test image', + cfg: undefined, + seed: 12345, + steps: undefined, + } as any, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + const callArgs = mockFetch.mock.calls[0][1]; + const requestBody = JSON.parse(callArgs?.body as string); + + expect(requestBody).toEqual({ + output_format: 'png', + safety_tolerance: 6, + prompt: 'Test image', + seed: 12345, + }); + + expect(requestBody).not.toHaveProperty('guidance'); + expect(requestBody).not.toHaveProperty('steps'); + }); + }); + + describe('Image URL handling', () => { + it('should convert single imageUrl to image_prompt base64', async () => { + // Arrange + const { parseDataUri } = await import('../utils/uriParser'); + const { imageUrlToBase64 } = await import('@/utils/imageToBase64'); + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + + const mockParseDataUri = vi.mocked(parseDataUri); + const mockImageUrlToBase64 = vi.mocked(imageUrlToBase64); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockParseDataUri.mockReturnValue({ type: 'url', base64: null, mimeType: null }); + mockImageUrlToBase64.mockResolvedValue({ + base64: 'base64EncodedImage', + mimeType: 'image/jpeg', + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const payload: CreateImagePayload = { + model: 'flux-pro-1.1', + params: { + prompt: 'Transform this image', + imageUrl: 'https://example.com/input.jpg', + }, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + expect(mockParseDataUri).toHaveBeenCalledWith('https://example.com/input.jpg'); + expect(mockImageUrlToBase64).toHaveBeenCalledWith('https://example.com/input.jpg'); + + const callArgs = mockFetch.mock.calls[0][1]; + const requestBody = JSON.parse(callArgs?.body as string); + + expect(requestBody).toEqual({ + output_format: 'png', + safety_tolerance: 6, + prompt: 'Transform this image', + image_prompt: 'base64EncodedImage', + }); + + expect(requestBody).not.toHaveProperty('imageUrl'); + }); + + it('should handle base64 imageUrl directly', async () => { + // Arrange + const { parseDataUri } = await import('../utils/uriParser'); + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + + const mockParseDataUri = vi.mocked(parseDataUri); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockParseDataUri.mockReturnValue({ + type: 'base64', + base64: '/9j/4AAQSkZJRgABAQEAYABgAAD', + mimeType: 'image/jpeg', + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const base64Image = 'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD'; + const payload: CreateImagePayload = { + model: 'flux-pro-1.1', + params: { + prompt: 'Transform this image', + imageUrl: base64Image, + }, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + const callArgs = mockFetch.mock.calls[0][1]; + const requestBody = JSON.parse(callArgs?.body as string); + + expect(requestBody.image_prompt).toBe('/9j/4AAQSkZJRgABAQEAYABgAAD'); + }); + + it('should convert multiple imageUrls for Kontext models', async () => { + // Arrange + const { parseDataUri } = await import('../utils/uriParser'); + const { imageUrlToBase64 } = await import('@/utils/imageToBase64'); + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + + const mockParseDataUri = vi.mocked(parseDataUri); + const mockImageUrlToBase64 = vi.mocked(imageUrlToBase64); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockParseDataUri.mockReturnValue({ type: 'url', base64: null, mimeType: null }); + mockImageUrlToBase64 + .mockResolvedValueOnce({ base64: 'base64image1', mimeType: 'image/jpeg' }) + .mockResolvedValueOnce({ base64: 'base64image2', mimeType: 'image/jpeg' }) + .mockResolvedValueOnce({ base64: 'base64image3', mimeType: 'image/jpeg' }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const payload: CreateImagePayload = { + model: 'flux-kontext-pro', + params: { + prompt: 'Create variation of these images', + imageUrls: [ + 'https://example.com/input1.jpg', + 'https://example.com/input2.jpg', + 'https://example.com/input3.jpg', + ], + }, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + const callArgs = mockFetch.mock.calls[0][1]; + const requestBody = JSON.parse(callArgs?.body as string); + + expect(requestBody).toEqual({ + output_format: 'png', + safety_tolerance: 6, + prompt: 'Create variation of these images', + input_image: 'base64image1', + input_image_2: 'base64image2', + input_image_3: 'base64image3', + }); + + expect(requestBody).not.toHaveProperty('imageUrls'); + }); + + it('should limit imageUrls to maximum 4 images', async () => { + // Arrange + const { parseDataUri } = await import('../utils/uriParser'); + const { imageUrlToBase64 } = await import('@/utils/imageToBase64'); + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + + const mockParseDataUri = vi.mocked(parseDataUri); + const mockImageUrlToBase64 = vi.mocked(imageUrlToBase64); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockParseDataUri.mockReturnValue({ type: 'url', base64: null, mimeType: null }); + mockImageUrlToBase64 + .mockResolvedValueOnce({ base64: 'base64image1', mimeType: 'image/jpeg' }) + .mockResolvedValueOnce({ base64: 'base64image2', mimeType: 'image/jpeg' }) + .mockResolvedValueOnce({ base64: 'base64image3', mimeType: 'image/jpeg' }) + .mockResolvedValueOnce({ base64: 'base64image4', mimeType: 'image/jpeg' }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const payload: CreateImagePayload = { + model: 'flux-kontext-max', + params: { + prompt: 'Create variation of these images', + imageUrls: [ + 'https://example.com/input1.jpg', + 'https://example.com/input2.jpg', + 'https://example.com/input3.jpg', + 'https://example.com/input4.jpg', + 'https://example.com/input5.jpg', // This should be ignored + ], + }, + }; + + // Act + await createBflImage(payload, mockOptions); + + // Assert + expect(mockImageUrlToBase64).toHaveBeenCalledTimes(4); + + const callArgs = mockFetch.mock.calls[0][1]; + const requestBody = JSON.parse(callArgs?.body as string); + + expect(requestBody).toEqual({ + output_format: 'png', + safety_tolerance: 6, + prompt: 'Create variation of these images', + input_image: 'base64image1', + input_image_2: 'base64image2', + input_image_3: 'base64image3', + input_image_4: 'base64image4', + }); + + expect(requestBody).not.toHaveProperty('input_image_5'); + }); + }); + + describe('Model endpoint mapping', () => { + it('should map models to correct endpoints', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const testCases = [ + { model: 'flux-dev', endpoint: '/v1/flux-dev' }, + { model: 'flux-pro', endpoint: '/v1/flux-pro' }, + { model: 'flux-pro-1.1', endpoint: '/v1/flux-pro-1.1' }, + { model: 'flux-pro-1.1-ultra', endpoint: '/v1/flux-pro-1.1-ultra' }, + { model: 'flux-kontext-pro', endpoint: '/v1/flux-kontext-pro' }, + { model: 'flux-kontext-max', endpoint: '/v1/flux-kontext-max' }, + ]; + + // Act & Assert + for (const { model, endpoint } of testCases) { + vi.clearAllMocks(); + + const payload: CreateImagePayload = { + model, + params: { + prompt: `Test image for ${model}`, + }, + }; + + await createBflImage(payload, mockOptions); + + expect(mockFetch).toHaveBeenCalledWith(`https://api.bfl.ai${endpoint}`, expect.any(Object)); + } + }); + + it('should throw error for unsupported model', async () => { + // Arrange + const payload: CreateImagePayload = { + model: 'unsupported-model', + params: { + prompt: 'Test image', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.objectContaining({ + message: 'Unsupported BFL model: unsupported-model', + }), + errorType: 'ModelNotFound', + provider: 'bfl', + }); + }); + + it('should use custom baseURL when provided', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://custom-api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockResolvedValue({ + imageUrl: 'https://example.com/result.jpg', + }); + + const customOptions = { + ...mockOptions, + baseURL: 'https://custom-api.bfl.ai', + }; + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test with custom URL', + }, + }; + + // Act + await createBflImage(payload, customOptions); + + // Assert + expect(mockFetch).toHaveBeenCalledWith( + 'https://custom-api.bfl.ai/v1/flux-dev', + expect.any(Object), + ); + }); + }); + + describe('Status handling', () => { + it('should return success when status is Ready with result', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + // Mock the asyncifyPolling to call checkStatus with Ready status + mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => { + const result = checkStatus({ + id: 'task-123', + status: BflStatusResponse.Ready, + result: { + sample: 'https://example.com/generated-image.jpg', + }, + }); + + if (result.status === 'success') { + return result.data; + } + throw result.error; + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test successful generation', + }, + }; + + // Act + const result = await createBflImage(payload, mockOptions); + + // Assert + expect(result).toEqual({ + imageUrl: 'https://example.com/generated-image.jpg', + }); + }); + + it('should throw error when status is Ready but no result', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => { + const result = checkStatus({ + id: 'task-123', + status: BflStatusResponse.Ready, + result: null, + }); + + if (result.status === 'success') { + return result.data; + } + throw result.error; + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test no result error', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.any(Object), + errorType: 'ProviderBizError', + provider: 'bfl', + }); + }); + + it('should handle error statuses', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + const errorStatuses = [ + BflStatusResponse.Error, + BflStatusResponse.ContentModerated, + BflStatusResponse.RequestModerated, + ]; + + for (const status of errorStatuses) { + mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => { + const result = checkStatus({ + id: 'task-123', + status, + details: { error: 'Test error details' }, + }); + + if (result.status === 'success') { + return result.data; + } + throw result.error; + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: `Test ${status} error`, + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.any(Object), + errorType: 'ProviderBizError', + provider: 'bfl', + }); + } + }); + + it('should handle TaskNotFound status', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => { + const result = checkStatus({ + id: 'task-123', + status: BflStatusResponse.TaskNotFound, + }); + + if (result.status === 'success') { + return result.data; + } + throw result.error; + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test task not found', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.any(Object), + errorType: 'ProviderBizError', + provider: 'bfl', + }); + }); + + it('should continue polling for Pending status', async () => { + // Arrange + const { asyncifyPolling } = await import('../utils/asyncifyPolling'); + const mockAsyncifyPolling = vi.mocked(asyncifyPolling); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + id: 'task-123', + polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123', + }), + } as Response); + + mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => { + // First call - Pending status + const pendingResult = checkStatus({ + id: 'task-123', + status: BflStatusResponse.Pending, + }); + + expect(pendingResult.status).toBe('pending'); + + // Simulate successful completion + const successResult = checkStatus({ + id: 'task-123', + status: BflStatusResponse.Ready, + result: { + sample: 'https://example.com/generated-image.jpg', + }, + }); + + return successResult.data; + }); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test pending status', + }, + }; + + // Act + const result = await createBflImage(payload, mockOptions); + + // Assert + expect(result).toEqual({ + imageUrl: 'https://example.com/generated-image.jpg', + }); + }); + }); + + describe('Error handling', () => { + it('should handle fetch errors during task submission', async () => { + // Arrange + mockFetch.mockRejectedValue(new Error('Network error')); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test network error', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toThrow(); + }); + + it('should handle HTTP error responses', async () => { + // Arrange + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + statusText: 'Bad Request', + json: () => + Promise.resolve({ + detail: [{ msg: 'Invalid prompt' }], + }), + } as Response); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test HTTP error', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.any(Object), + errorType: 'ProviderBizError', + provider: 'bfl', + }); + }); + + it('should handle HTTP error responses without detail', async () => { + // Arrange + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + json: () => Promise.resolve({}), + } as Response); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test HTTP error without detail', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.any(Object), + errorType: 'ProviderBizError', + provider: 'bfl', + }); + }); + + it('should handle non-JSON error responses', async () => { + // Arrange + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + json: () => Promise.reject(new Error('Invalid JSON')), + } as Response); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test non-JSON error', + }, + }; + + // Act & Assert + await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({ + error: expect.any(Object), + errorType: 'ProviderBizError', + provider: 'bfl', + }); + }); + }); +}); diff --git a/packages/model-runtime/src/bfl/createImage.ts b/packages/model-runtime/src/bfl/createImage.ts new file mode 100644 index 0000000000..12a7073a11 --- /dev/null +++ b/packages/model-runtime/src/bfl/createImage.ts @@ -0,0 +1,279 @@ +import createDebug from 'debug'; + +import { RuntimeImageGenParamsValue } from '@/libs/standard-parameters/index'; +import { imageUrlToBase64 } from '@/utils/imageToBase64'; + +import { AgentRuntimeErrorType } from '../error'; +import { CreateImagePayload, CreateImageResponse } from '../types/image'; +import { type TaskResult, asyncifyPolling } from '../utils/asyncifyPolling'; +import { AgentRuntimeError } from '../utils/createError'; +import { parseDataUri } from '../utils/uriParser'; +import { + BFL_ENDPOINTS, + BflAsyncResponse, + BflModelId, + BflRequest, + BflResultResponse, + BflStatusResponse, +} from './types'; + +const log = createDebug('lobe-image:bfl'); + +const BASE_URL = 'https://api.bfl.ai'; + +interface BflCreateImageOptions { + apiKey: string; + baseURL?: string; + provider: string; +} + +/** + * Convert image URL to base64 format required by BFL API + */ +async function convertImageToBase64(imageUrl: string): Promise { + try { + const { type } = parseDataUri(imageUrl); + + if (type === 'base64') { + // Already in base64 format, extract the base64 part + const base64Match = imageUrl.match(/^data:[^;]+;base64,(.+)$/); + if (base64Match) { + return base64Match[1]; + } + throw new Error('Invalid base64 format'); + } + + if (type === 'url') { + // Convert URL to base64 + const { base64 } = await imageUrlToBase64(imageUrl); + return base64; + } + + throw new Error(`Invalid image URL format: ${imageUrl}`); + } catch (error) { + log('Error converting image to base64: %O', error); + throw error; + } +} + +/** + * Build request payload for different BFL models + */ +async function buildRequestPayload( + model: BflModelId, + params: CreateImagePayload['params'], +): Promise { + log('Building request payload for model: %s', model); + + // Define parameter mapping (BFL API specific) + const paramsMap = new Map([ + ['aspectRatio', 'aspect_ratio'], + ['cfg', 'guidance'], + ]); + + // Fixed parameters for all BFL models + const defaultPayload: Record = { + output_format: 'png', + safety_tolerance: 6, + ...(model.includes('ultra') && { raw: true }), + }; + + // Map user parameters, filtering out undefined values + const userPayload: Record = Object.fromEntries( + (Object.entries(params) as [keyof typeof params, any][]) + .filter(([, value]) => value !== undefined) + .map(([key, value]) => [paramsMap.get(key) ?? key, value]), + ); + + // Handle multiple input images (imageUrls) for Kontext models + if (params.imageUrls && params.imageUrls.length > 0) { + for (let i = 0; i < Math.min(params.imageUrls.length, 4); i++) { + const fieldName = i === 0 ? 'input_image' : `input_image_${i + 1}`; + userPayload[fieldName] = await convertImageToBase64(params.imageUrls[i]); + } + // Remove the original imageUrls field as it's now mapped to input_image_* + delete userPayload.imageUrls; + } + + // Handle single image input (imageUrl) + if (params.imageUrl) { + userPayload.image_prompt = await convertImageToBase64(params.imageUrl); + // Remove the original imageUrl field as it's now mapped to image_prompt + delete userPayload.imageUrl; + } + + // Combine default and user payload + const payload = { + ...defaultPayload, + ...userPayload, + }; + + return payload as BflRequest; +} + +/** + * Submit image generation task to BFL API + */ +async function submitTask( + model: BflModelId, + payload: BflRequest, + options: BflCreateImageOptions, +): Promise { + const endpoint = BFL_ENDPOINTS[model]; + const url = `${options.baseURL || BASE_URL}${endpoint}`; + + log('Submitting task to: %s', url); + + const response = await fetch(url, { + body: JSON.stringify(payload), + headers: { + 'Content-Type': 'application/json', + 'x-key': options.apiKey, + }, + method: 'POST', + }); + + if (!response.ok) { + let errorData; + try { + errorData = await response.json(); + } catch { + // Failed to parse JSON error response + } + + throw new Error( + `BFL API error (${response.status}): ${errorData?.detail?.[0]?.msg || response.statusText}`, + ); + } + + const data: BflAsyncResponse = await response.json(); + log('Task submitted successfully with ID: %s', data.id); + + return data; +} + +/** + * Query task status using BFL API + */ +async function queryTaskStatus( + pollingUrl: string, + options: BflCreateImageOptions, +): Promise { + log('Querying task status using polling URL: %s', pollingUrl); + + const response = await fetch(pollingUrl, { + headers: { + 'accept': 'application/json', + 'x-key': options.apiKey, + }, + method: 'GET', + }); + + if (!response.ok) { + let errorData; + try { + errorData = await response.json(); + } catch { + // Failed to parse JSON error response + } + + throw new Error( + `Failed to query task status (${response.status}): ${errorData?.detail?.[0]?.msg || response.statusText}`, + ); + } + + return response.json(); +} + +/** + * Create image using BFL API with async task polling + */ +export async function createBflImage( + payload: CreateImagePayload, + options: BflCreateImageOptions, +): Promise { + const { model, params } = payload; + + if (!BFL_ENDPOINTS[model as BflModelId]) { + throw AgentRuntimeError.createImage({ + error: new Error(`Unsupported BFL model: ${model}`), + errorType: AgentRuntimeErrorType.ModelNotFound, + provider: options.provider, + }); + } + + try { + // 1. Build request payload + const requestPayload = await buildRequestPayload(model as BflModelId, params); + + // 2. Submit image generation task + const taskResponse = await submitTask(model as BflModelId, requestPayload, options); + + // 3. Poll task status until completion using asyncifyPolling + return await asyncifyPolling({ + checkStatus: (taskStatus: BflResultResponse): TaskResult => { + log('Task %s status: %s', taskResponse.id, taskStatus.status); + + switch (taskStatus.status) { + case BflStatusResponse.Ready: { + if (!taskStatus.result?.sample) { + return { + error: new Error('Task succeeded but no image generated'), + status: 'failed', + }; + } + + const imageUrl = taskStatus.result.sample; + log('Image generated successfully: %s', imageUrl); + + return { + data: { imageUrl }, + status: 'success', + }; + } + case BflStatusResponse.Error: + case BflStatusResponse.ContentModerated: + case BflStatusResponse.RequestModerated: { + // Extract error details if available, otherwise use status + let errorMessage = `Image generation failed with status: ${taskStatus.status}`; + + // Check for additional error details in various possible fields + if (taskStatus.details && typeof taskStatus.details === 'object') { + errorMessage += ` - Details: ${JSON.stringify(taskStatus.details)}`; + } else if (taskStatus.result && typeof taskStatus.result === 'object') { + errorMessage += ` - Result: ${JSON.stringify(taskStatus.result)}`; + } + + return { + error: new Error(errorMessage), + status: 'failed', + }; + } + case BflStatusResponse.TaskNotFound: { + return { + error: new Error('Task not found - may have expired'), + status: 'failed', + }; + } + default: { + // Continue polling for Pending status or other unknown statuses + return { status: 'pending' }; + } + } + }, + logger: { + debug: (message: any, ...args: any[]) => log(message, ...args), + error: (message: any, ...args: any[]) => log(message, ...args), + }, + pollingQuery: () => queryTaskStatus(taskResponse.polling_url, options), + }); + } catch (error) { + log('Error in createBflImage: %O', error); + + throw AgentRuntimeError.createImage({ + error: error as any, + errorType: 'ProviderBizError', + provider: options.provider, + }); + } +} diff --git a/packages/model-runtime/src/bfl/index.test.ts b/packages/model-runtime/src/bfl/index.test.ts new file mode 100644 index 0000000000..fc45bef9e2 --- /dev/null +++ b/packages/model-runtime/src/bfl/index.test.ts @@ -0,0 +1,269 @@ +// @vitest-environment node +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { CreateImagePayload } from '@/libs/model-runtime/types/image'; + +import { LobeBflAI } from './index'; + +// Mock the createBflImage function +vi.mock('./createImage', () => ({ + createBflImage: vi.fn(), +})); + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +const bizErrorType = 'ProviderBizError'; +const invalidErrorType = 'InvalidProviderAPIKey'; + +let instance: LobeBflAI; + +beforeEach(() => { + vi.clearAllMocks(); + instance = new LobeBflAI({ apiKey: 'test-api-key' }); +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('LobeBflAI', () => { + describe('init', () => { + it('should correctly initialize with an API key', () => { + const instance = new LobeBflAI({ apiKey: 'test_api_key' }); + expect(instance).toBeInstanceOf(LobeBflAI); + }); + + it('should initialize with custom baseURL', () => { + const customBaseURL = 'https://custom-api.bfl.ai'; + const instance = new LobeBflAI({ + apiKey: 'test_api_key', + baseURL: customBaseURL, + }); + expect(instance).toBeInstanceOf(LobeBflAI); + }); + + it('should throw InvalidProviderAPIKey if no apiKey is provided', () => { + expect(() => { + new LobeBflAI({}); + }).toThrow(); + }); + + it('should throw InvalidProviderAPIKey if apiKey is undefined', () => { + expect(() => { + new LobeBflAI({ apiKey: undefined }); + }).toThrow(); + }); + }); + + describe('createImage', () => { + let mockCreateBflImage: any; + + beforeEach(async () => { + const { createBflImage } = await import('./createImage'); + mockCreateBflImage = vi.mocked(createBflImage); + }); + + it('should create image successfully with basic parameters', async () => { + // Arrange + const mockImageResponse = { + imageUrl: 'https://example.com/generated-image.jpg', + }; + mockCreateBflImage.mockResolvedValue(mockImageResponse); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'A beautiful landscape with mountains', + width: 1024, + height: 1024, + }, + }; + + // Act + const result = await instance.createImage(payload); + + // Assert + expect(mockCreateBflImage).toHaveBeenCalledWith(payload, { + apiKey: 'test-api-key', + baseURL: undefined, + provider: 'bfl', + }); + expect(result).toEqual(mockImageResponse); + }); + + it('should pass custom baseURL to createBflImage', async () => { + // Arrange + const customBaseURL = 'https://custom-api.bfl.ai'; + const customInstance = new LobeBflAI({ + apiKey: 'test-api-key', + baseURL: customBaseURL, + }); + + const mockImageResponse = { + imageUrl: 'https://example.com/generated-image.jpg', + }; + mockCreateBflImage.mockResolvedValue(mockImageResponse); + + const payload: CreateImagePayload = { + model: 'flux-pro', + params: { + prompt: 'Test image', + }, + }; + + // Act + await customInstance.createImage(payload); + + // Assert + expect(mockCreateBflImage).toHaveBeenCalledWith(payload, { + apiKey: 'test-api-key', + baseURL: customBaseURL, + provider: 'bfl', + }); + }); + + describe('Error handling', () => { + it('should throw InvalidProviderAPIKey on 401 error', async () => { + // Arrange + const apiError = new Error('Unauthorized') as Error & { status: number }; + apiError.status = 401; + mockCreateBflImage.mockRejectedValue(apiError); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test image', + }, + }; + + // Act & Assert + await expect(instance.createImage(payload)).rejects.toEqual({ + error: { error: apiError }, + errorType: invalidErrorType, + }); + }); + + it('should throw ProviderBizError on other errors', async () => { + // Arrange + const apiError = new Error('Some other error'); + mockCreateBflImage.mockRejectedValue(apiError); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test image', + }, + }; + + // Act & Assert + await expect(instance.createImage(payload)).rejects.toEqual({ + error: { error: apiError }, + errorType: bizErrorType, + }); + }); + + it('should throw ProviderBizError on non-401 status errors', async () => { + // Arrange + const apiError = new Error('Server error') as Error & { status: number }; + apiError.status = 500; + mockCreateBflImage.mockRejectedValue(apiError); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Test image', + }, + }; + + // Act & Assert + await expect(instance.createImage(payload)).rejects.toEqual({ + error: { error: apiError }, + errorType: bizErrorType, + }); + }); + + it('should throw ProviderBizError on errors without status property', async () => { + // Arrange + const apiError = new Error('Network error'); + mockCreateBflImage.mockRejectedValue(apiError); + + const payload: CreateImagePayload = { + model: 'flux-pro-1.1', + params: { + prompt: 'Test image', + }, + }; + + // Act & Assert + await expect(instance.createImage(payload)).rejects.toEqual({ + error: { error: apiError }, + errorType: bizErrorType, + }); + }); + }); + + describe('Edge cases', () => { + it('should handle different model types', async () => { + // Arrange + const mockImageResponse = { + imageUrl: 'https://example.com/generated-image.jpg', + }; + mockCreateBflImage.mockResolvedValue(mockImageResponse); + + const models = [ + 'flux-dev', + 'flux-pro', + 'flux-pro-1.1', + 'flux-pro-1.1-ultra', + 'flux-kontext-pro', + 'flux-kontext-max', + ]; + + // Act & Assert + for (const model of models) { + const payload: CreateImagePayload = { + model, + params: { + prompt: `Test image for ${model}`, + }, + }; + + await instance.createImage(payload); + + expect(mockCreateBflImage).toHaveBeenCalledWith(payload, { + apiKey: 'test-api-key', + baseURL: undefined, + provider: 'bfl', + }); + } + }); + + it('should handle empty params object', async () => { + // Arrange + const mockImageResponse = { + imageUrl: 'https://example.com/generated-image.jpg', + }; + mockCreateBflImage.mockResolvedValue(mockImageResponse); + + const payload: CreateImagePayload = { + model: 'flux-dev', + params: { + prompt: 'Empty params test', + }, + }; + + // Act + const result = await instance.createImage(payload); + + // Assert + expect(mockCreateBflImage).toHaveBeenCalledWith(payload, { + apiKey: 'test-api-key', + baseURL: undefined, + provider: 'bfl', + }); + expect(result).toEqual(mockImageResponse); + }); + }); + }); +}); diff --git a/packages/model-runtime/src/bfl/index.ts b/packages/model-runtime/src/bfl/index.ts new file mode 100644 index 0000000000..c00d5b38f4 --- /dev/null +++ b/packages/model-runtime/src/bfl/index.ts @@ -0,0 +1,49 @@ +import createDebug from 'debug'; +import { ClientOptions } from 'openai'; + +import { LobeRuntimeAI } from '../BaseAI'; +import { AgentRuntimeErrorType } from '../error'; +import { CreateImagePayload, CreateImageResponse } from '../types/image'; +import { AgentRuntimeError } from '../utils/createError'; +import { createBflImage } from './createImage'; + +const log = createDebug('lobe-image:bfl'); + +export class LobeBflAI implements LobeRuntimeAI { + private apiKey: string; + baseURL?: string; + + constructor({ apiKey, baseURL }: ClientOptions = {}) { + if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey); + + this.apiKey = apiKey; + this.baseURL = baseURL || undefined; + + log('BFL AI initialized'); + } + + async createImage(payload: CreateImagePayload): Promise { + const { model, params } = payload; + log('Creating image with model: %s and params: %O', model, params); + + try { + return await createBflImage(payload, { + apiKey: this.apiKey, + baseURL: this.baseURL, + provider: 'bfl', + }); + } catch (error) { + log('Error in createImage: %O', error); + + // Check for authentication errors based on HTTP status or error properties + if (error instanceof Error && 'status' in error && (error as any).status === 401) { + throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey, { + error, + }); + } + + // Wrap other errors + throw AgentRuntimeError.createError(AgentRuntimeErrorType.ProviderBizError, { error }); + } + } +} diff --git a/packages/model-runtime/src/bfl/types.ts b/packages/model-runtime/src/bfl/types.ts new file mode 100644 index 0000000000..5a281017a6 --- /dev/null +++ b/packages/model-runtime/src/bfl/types.ts @@ -0,0 +1,113 @@ +// BFL API Types + +export enum BflStatusResponse { + ContentModerated = 'Content Moderated', + Error = 'Error', + Pending = 'Pending', + Ready = 'Ready', + RequestModerated = 'Request Moderated', + TaskNotFound = 'Task not found', +} + +export interface BflAsyncResponse { + id: string; + polling_url: string; +} + +export interface BflAsyncWebhookResponse { + id: string; + status: string; + webhook_url: string; +} + +export interface BflResultResponse { + details?: Record | null; + id: string; + preview?: Record | null; + progress?: number | null; + result?: any; + status: BflStatusResponse; +} + +// Kontext series request (flux-kontext-pro, flux-kontext-max) +export interface BflFluxKontextRequest { + aspect_ratio?: string | null; + input_image?: string | null; + input_image_2?: string | null; + input_image_3?: string | null; + input_image_4?: string | null; + output_format?: 'jpeg' | 'png' | null; + prompt: string; + prompt_upsampling?: boolean; + safety_tolerance?: number; + seed?: number | null; + webhook_secret?: string | null; + webhook_url?: string | null; +} + +// FLUX 1.1 Pro request +export interface BflFluxPro11Request { + height?: number; + image_prompt?: string | null; + output_format?: 'jpeg' | 'png' | null; + prompt?: string | null; + prompt_upsampling?: boolean; + safety_tolerance?: number; + seed?: number | null; + webhook_secret?: string | null; + webhook_url?: string | null; + width?: number; +} + +// FLUX 1.1 Pro Ultra request +export interface BflFluxPro11UltraRequest { + aspect_ratio?: string; + prompt: string; + raw?: boolean; + safety_tolerance?: number; + seed?: number | null; +} + +// FLUX Pro request +export interface BflFluxProRequest { + guidance?: number; + height?: number; + image_prompt?: string | null; + prompt?: string | null; + safety_tolerance?: number; + seed?: number | null; + steps?: number; + width?: number; +} + +// FLUX Dev request +export interface BflFluxDevRequest { + guidance?: number; + height?: number; + image_prompt?: string | null; + prompt: string; + safety_tolerance?: number; + seed?: number | null; + steps?: number; + width?: number; +} + +// Model endpoint mapping +export const BFL_ENDPOINTS = { + 'flux-dev': '/v1/flux-dev', + 'flux-kontext-max': '/v1/flux-kontext-max', + 'flux-kontext-pro': '/v1/flux-kontext-pro', + 'flux-pro': '/v1/flux-pro', + 'flux-pro-1.1': '/v1/flux-pro-1.1', + 'flux-pro-1.1-ultra': '/v1/flux-pro-1.1-ultra', +} as const; + +export type BflModelId = keyof typeof BFL_ENDPOINTS; + +// Union type for all request types +export type BflRequest = + | BflFluxKontextRequest + | BflFluxPro11Request + | BflFluxPro11UltraRequest + | BflFluxProRequest + | BflFluxDevRequest; diff --git a/packages/model-runtime/src/index.ts b/packages/model-runtime/src/index.ts index f2ccadbf48..ff61def98b 100644 --- a/packages/model-runtime/src/index.ts +++ b/packages/model-runtime/src/index.ts @@ -3,6 +3,7 @@ export { LobeAzureAI } from './azureai'; export { LobeAzureOpenAI } from './azureOpenai'; export * from './BaseAI'; export { LobeBedrockAI } from './bedrock'; +export { LobeBflAI } from './bfl'; export { LobeDeepSeekAI } from './deepseek'; export * from './error'; export { LobeGoogleAI } from './google'; diff --git a/packages/model-runtime/src/qwen/createImage.ts b/packages/model-runtime/src/qwen/createImage.ts index 20ad8ff8d8..dca58081e1 100644 --- a/packages/model-runtime/src/qwen/createImage.ts +++ b/packages/model-runtime/src/qwen/createImage.ts @@ -1,6 +1,7 @@ import createDebug from 'debug'; import { CreateImagePayload, CreateImageResponse } from '../types/image'; +import { type TaskResult, asyncifyPolling } from '../utils/asyncifyPolling'; import { AgentRuntimeError } from '../utils/createError'; import { CreateImageOptions } from '../utils/openaiCompatibleFactory'; @@ -139,93 +140,47 @@ export async function createQwenImage( // 1. Create image generation task const taskId = await createImageTask(payload, apiKey); - // 2. Poll task status until completion - let taskStatus: QwenImageTaskResponse | null = null; - let retries = 0; - let consecutiveFailures = 0; - const maxConsecutiveFailures = 3; // Allow up to 3 consecutive query failures - // Using Infinity for maxRetries is safe because: - // 1. Vercel runtime has execution time limits - // 2. Qwen's API will eventually return FAILED status for timed-out tasks - // 3. Our exponential backoff ensures reasonable retry intervals - const maxRetries = Infinity; - const initialRetryInterval = 500; // 500ms initial interval - const maxRetryInterval = 5000; // 5 seconds max interval - const backoffMultiplier = 1.5; // exponential backoff multiplier + // 2. Poll task status until completion using asyncifyPolling + const result = await asyncifyPolling({ + checkStatus: (taskStatus: QwenImageTaskResponse): TaskResult => { + log('Task %s status: %s', taskId, taskStatus.output.task_status); - while (retries < maxRetries) { - try { - taskStatus = await queryTaskStatus(taskId, apiKey); - consecutiveFailures = 0; // Reset consecutive failures on success - } catch (error) { - consecutiveFailures++; - log( - 'Failed to query task status (attempt %d/%d, consecutive failures: %d/%d): %O', - retries + 1, - maxRetries, - consecutiveFailures, - maxConsecutiveFailures, - error, - ); + if (taskStatus.output.task_status === 'SUCCEEDED') { + if (!taskStatus.output.results || taskStatus.output.results.length === 0) { + return { + error: new Error('Task succeeded but no images generated'), + status: 'failed', + }; + } - // If we've failed too many times in a row, give up - if (consecutiveFailures >= maxConsecutiveFailures) { - throw new Error( - `Failed to query task status after ${consecutiveFailures} consecutive attempts: ${error}`, - ); + const imageUrl = taskStatus.output.results[0].url; + log('Image generated successfully: %s', imageUrl); + + return { + data: { imageUrl }, + status: 'success', + }; } - // Wait before retrying - const currentRetryInterval = Math.min( - initialRetryInterval * Math.pow(backoffMultiplier, retries), - maxRetryInterval, - ); - await new Promise((resolve) => { - setTimeout(resolve, currentRetryInterval); - }); - retries++; - continue; // Skip the rest of the loop and retry - } - - // At this point, taskStatus should not be null since we just got it successfully - log( - 'Task %s status: %s (attempt %d/%d)', - taskId, - taskStatus!.output.task_status, - retries + 1, - maxRetries, - ); - - if (taskStatus!.output.task_status === 'SUCCEEDED') { - if (!taskStatus!.output.results || taskStatus!.output.results.length === 0) { - throw new Error('Task succeeded but no images generated'); + if (taskStatus.output.task_status === 'FAILED') { + const errorMessage = taskStatus.output.error_message || 'Image generation task failed'; + return { + error: new Error(`Qwen image generation failed: ${errorMessage}`), + status: 'failed', + }; } - // Return the first generated image - const imageUrl = taskStatus!.output.results[0].url; - log('Image generated successfully: %s', imageUrl); + // Continue polling for pending/running status or other unknown statuses + return { status: 'pending' }; + }, + logger: { + debug: (message: any, ...args: any[]) => log(message, ...args), + error: (message: any, ...args: any[]) => log(message, ...args), + }, + pollingQuery: () => queryTaskStatus(taskId, apiKey), + }); - return { imageUrl }; - } else if (taskStatus!.output.task_status === 'FAILED') { - throw new Error(taskStatus!.output.error_message || 'Image generation task failed'); - } - - // Calculate dynamic retry interval with exponential backoff - const currentRetryInterval = Math.min( - initialRetryInterval * Math.pow(backoffMultiplier, retries), - maxRetryInterval, - ); - - log('Waiting %dms before next retry', currentRetryInterval); - - // Wait before retrying - await new Promise((resolve) => { - setTimeout(resolve, currentRetryInterval); - }); - retries++; - } - - throw new Error(`Image generation timeout after ${maxRetries} attempts`); + return result; } catch (error) { log('Error in createQwenImage: %O', error); diff --git a/packages/model-runtime/src/runtimeMap.ts b/packages/model-runtime/src/runtimeMap.ts index a7fe2e4b2f..67e88c4519 100644 --- a/packages/model-runtime/src/runtimeMap.ts +++ b/packages/model-runtime/src/runtimeMap.ts @@ -7,6 +7,7 @@ import { LobeAzureOpenAI } from './azureOpenai'; import { LobeAzureAI } from './azureai'; import { LobeBaichuanAI } from './baichuan'; import { LobeBedrockAI } from './bedrock'; +import { LobeBflAI } from './bfl'; import { LobeCloudflareAI } from './cloudflare'; import { LobeCohereAI } from './cohere'; import { LobeDeepSeekAI } from './deepseek'; @@ -65,6 +66,7 @@ export const providerRuntimeMap = { azureai: LobeAzureAI, baichuan: LobeBaichuanAI, bedrock: LobeBedrockAI, + bfl: LobeBflAI, cloudflare: LobeCloudflareAI, cohere: LobeCohereAI, deepseek: LobeDeepSeekAI, diff --git a/packages/model-runtime/src/utils/asyncifyPolling.test.ts b/packages/model-runtime/src/utils/asyncifyPolling.test.ts new file mode 100644 index 0000000000..002a014959 --- /dev/null +++ b/packages/model-runtime/src/utils/asyncifyPolling.test.ts @@ -0,0 +1,491 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { type TaskResult, asyncifyPolling } from './asyncifyPolling'; + +describe('asyncifyPolling', () => { + describe('basic functionality', () => { + it('should return data when task succeeds immediately', async () => { + const mockTask = vi.fn().mockResolvedValue({ status: 'completed', data: 'result' }); + const mockCheckStatus = vi.fn().mockReturnValue({ + status: 'success', + data: 'result', + } as TaskResult); + + const result = await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + }); + + expect(result).toBe('result'); + expect(mockTask).toHaveBeenCalledTimes(1); + expect(mockCheckStatus).toHaveBeenCalledTimes(1); + }); + + it('should poll multiple times until success', async () => { + const mockTask = vi + .fn() + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'completed', data: 'final-result' }); + + const mockCheckStatus = vi + .fn() + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'success', data: 'final-result' }); + + const result = await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + initialInterval: 10, // fast test + }); + + expect(result).toBe('final-result'); + expect(mockTask).toHaveBeenCalledTimes(3); + expect(mockCheckStatus).toHaveBeenCalledTimes(3); + }); + + it('should throw error when task fails', async () => { + const mockTask = vi.fn().mockResolvedValue({ status: 'failed', error: 'Task failed' }); + const mockCheckStatus = vi.fn().mockReturnValue({ + status: 'failed', + error: new Error('Task failed'), + }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + }), + ).rejects.toThrow('Task failed'); + + expect(mockTask).toHaveBeenCalledTimes(1); + }); + + it('should handle pending status correctly', async () => { + const mockTask = vi + .fn() + .mockResolvedValueOnce({ status: 'processing' }) + .mockResolvedValueOnce({ status: 'done' }); + + const mockCheckStatus = vi + .fn() + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'success', data: 'completed' }); + + const result = await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + initialInterval: 10, + }); + + expect(result).toBe('completed'); + expect(mockTask).toHaveBeenCalledTimes(2); + }); + }); + + describe('retry mechanism', () => { + it('should retry with exponential backoff', async () => { + const startTime = Date.now(); + const mockTask = vi + .fn() + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'success' }); + + const mockCheckStatus = vi + .fn() + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'success', data: 'done' }); + + await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + initialInterval: 50, + backoffMultiplier: 2, + maxInterval: 200, + }); + + const elapsed = Date.now() - startTime; + // Should wait at least 50ms + 100ms = 150ms + expect(elapsed).toBeGreaterThan(140); + }); + + it('should respect maxInterval limit', async () => { + const intervals: number[] = []; + const originalSetTimeout = global.setTimeout; + + global.setTimeout = vi.fn((callback, delay) => { + intervals.push(delay as number); + return originalSetTimeout(callback, 1); // fast execution + }) as any; + + const mockTask = vi + .fn() + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'success' }); + + const mockCheckStatus = vi + .fn() + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'success', data: 'done' }); + + await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + initialInterval: 100, + backoffMultiplier: 3, + maxInterval: 200, + }); + + // Intervals should be: 100, 200 (capped), 200 (capped) + expect(intervals).toEqual([100, 200, 200]); + + global.setTimeout = originalSetTimeout; + }); + + it('should stop after maxRetries', async () => { + const mockTask = vi.fn().mockResolvedValue({ status: 'pending' }); + const mockCheckStatus = vi.fn().mockReturnValue({ status: 'pending' }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxRetries: 3, + initialInterval: 1, + }), + ).rejects.toThrow(/timeout after 3 attempts/); + + expect(mockTask).toHaveBeenCalledTimes(3); + }); + }); + + describe('error handling', () => { + it('should handle consecutive failures', async () => { + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Network error 1')) + .mockRejectedValueOnce(new Error('Network error 2')) + .mockResolvedValueOnce({ status: 'success' }); + + const mockCheckStatus = vi.fn().mockReturnValue({ status: 'success', data: 'recovered' }); + + const result = await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxConsecutiveFailures: 3, + initialInterval: 1, + }); + + expect(result).toBe('recovered'); + expect(mockTask).toHaveBeenCalledTimes(3); + }); + + it('should throw after maxConsecutiveFailures', async () => { + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Network error 1')) + .mockRejectedValueOnce(new Error('Network error 2')) + .mockRejectedValueOnce(new Error('Network error 3')); + + const mockCheckStatus = vi.fn(); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxConsecutiveFailures: 2, // 允许最多2次连续失败 + initialInterval: 1, + }), + ).rejects.toThrow(/consecutive attempts/); + + expect(mockTask).toHaveBeenCalledTimes(2); // 第1次失败,第2次失败,然后抛出错误 + expect(mockCheckStatus).not.toHaveBeenCalled(); + }); + + it('should reset consecutive failures on success', async () => { + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Network error 1')) // Failure 1 (consecutiveFailures=1) + .mockResolvedValueOnce({ status: 'pending' }) // Success 1 (reset to 0) + .mockRejectedValueOnce(new Error('Network error 2')) // Failure 2 (consecutiveFailures=1) + .mockRejectedValueOnce(new Error('Network error 3')) // Failure 3 (consecutiveFailures=2) + .mockResolvedValueOnce({ status: 'success' }); // Success 2 (return result) + + const mockCheckStatus = vi + .fn() + .mockReturnValueOnce({ status: 'pending' }) // For success 1 + .mockReturnValueOnce({ status: 'success', data: 'final' }); // For success 2 + + const result = await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxConsecutiveFailures: 3, // Allow up to 3 consecutive failures (since there are 2 consecutive failures) + initialInterval: 1, + }); + + expect(result).toBe('final'); + expect(mockTask).toHaveBeenCalledTimes(5); // Total 5 calls + }); + }); + + describe('configuration', () => { + it('should use custom intervals and multipliers', async () => { + const intervals: number[] = []; + const originalSetTimeout = global.setTimeout; + + global.setTimeout = vi.fn((callback, delay) => { + intervals.push(delay as number); + return originalSetTimeout(callback, 1); + }) as any; + + const mockTask = vi + .fn() + .mockResolvedValueOnce({ status: 'pending' }) + .mockResolvedValueOnce({ status: 'success' }); + + const mockCheckStatus = vi + .fn() + .mockReturnValueOnce({ status: 'pending' }) + .mockReturnValueOnce({ status: 'success', data: 'done' }); + + await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + initialInterval: 200, + backoffMultiplier: 1.2, + }); + + expect(intervals[0]).toBe(200); + + global.setTimeout = originalSetTimeout; + }); + + it('should accept custom logger function', async () => { + const mockLogger = { + debug: vi.fn(), + error: vi.fn(), + }; + + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Test error')) + .mockResolvedValueOnce({ status: 'success' }); + + const mockCheckStatus = vi.fn().mockReturnValue({ status: 'success', data: 'done' }); + + await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + logger: mockLogger, + maxConsecutiveFailures: 3, + initialInterval: 1, + }); + + expect(mockLogger.debug).toHaveBeenCalled(); + expect(mockLogger.error).toHaveBeenCalled(); + }); + }); + + describe('edge cases', () => { + it('should handle immediate failure', async () => { + const mockTask = vi.fn().mockResolvedValue({ error: 'immediate failure' }); + const mockCheckStatus = vi.fn().mockReturnValue({ + status: 'failed', + error: new Error('immediate failure'), + }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + }), + ).rejects.toThrow('immediate failure'); + + expect(mockTask).toHaveBeenCalledTimes(1); + }); + + it('should handle task throwing exceptions', async () => { + const mockTask = vi.fn().mockRejectedValue(new Error('Task exception')); + const mockCheckStatus = vi.fn(); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxConsecutiveFailures: 1, + initialInterval: 1, + }), + ).rejects.toThrow(/consecutive attempts/); + }); + + it('should timeout correctly with maxRetries = 1', async () => { + const mockTask = vi.fn().mockResolvedValue({ status: 'pending' }); + const mockCheckStatus = vi.fn().mockReturnValue({ status: 'pending' }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxRetries: 1, + initialInterval: 1, + }), + ).rejects.toThrow(/timeout after 1 attempts/); + + expect(mockTask).toHaveBeenCalledTimes(1); + }); + }); + + describe('custom error handling', () => { + it('should allow continuing polling via onPollingError', async () => { + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Network error')) + .mockRejectedValueOnce(new Error('Another error')) + .mockResolvedValueOnce({ status: 'success' }); + + const mockCheckStatus = vi.fn().mockReturnValue({ status: 'success', data: 'final-result' }); + + const onPollingError = vi.fn().mockReturnValue({ + isContinuePolling: true, + }); + + const result = await asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + onPollingError, + initialInterval: 1, + }); + + expect(result).toBe('final-result'); + expect(mockTask).toHaveBeenCalledTimes(3); + expect(onPollingError).toHaveBeenCalledTimes(2); + + // Check that error context was passed correctly + expect(onPollingError).toHaveBeenCalledWith({ + error: expect.any(Error), + retries: expect.any(Number), + consecutiveFailures: expect.any(Number), + }); + }); + + it('should stop polling when onPollingError returns false', async () => { + const mockTask = vi.fn().mockRejectedValue(new Error('Fatal error')); + const mockCheckStatus = vi.fn(); + + const onPollingError = vi.fn().mockReturnValue({ + isContinuePolling: false, + }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + onPollingError, + initialInterval: 1, + }), + ).rejects.toThrow('Fatal error'); + + expect(mockTask).toHaveBeenCalledTimes(1); + expect(onPollingError).toHaveBeenCalledTimes(1); + expect(mockCheckStatus).not.toHaveBeenCalled(); + }); + + it('should throw custom error when provided by onPollingError', async () => { + const mockTask = vi.fn().mockRejectedValue(new Error('Original error')); + const mockCheckStatus = vi.fn(); + + const customError = new Error('Custom error message'); + const onPollingError = vi.fn().mockReturnValue({ + isContinuePolling: false, + error: customError, + }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + onPollingError, + initialInterval: 1, + }), + ).rejects.toThrow('Custom error message'); + + expect(onPollingError).toHaveBeenCalledWith({ + error: expect.objectContaining({ message: 'Original error' }), + retries: 0, + consecutiveFailures: 1, + }); + }); + + it('should provide correct context information to onPollingError', async () => { + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Error 1')) + .mockRejectedValueOnce(new Error('Error 2')) + .mockRejectedValueOnce(new Error('Error 3')); + + const mockCheckStatus = vi.fn(); + + const onPollingError = vi + .fn() + .mockReturnValueOnce({ isContinuePolling: true }) + .mockReturnValueOnce({ isContinuePolling: true }) + .mockReturnValueOnce({ isContinuePolling: false }); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + onPollingError, + initialInterval: 1, + }), + ).rejects.toThrow('Error 3'); + + // Verify context progression + expect(onPollingError).toHaveBeenNthCalledWith(1, { + error: expect.objectContaining({ message: 'Error 1' }), + retries: 0, + consecutiveFailures: 1, + }); + + expect(onPollingError).toHaveBeenNthCalledWith(2, { + error: expect.objectContaining({ message: 'Error 2' }), + retries: 1, + consecutiveFailures: 2, + }); + + expect(onPollingError).toHaveBeenNthCalledWith(3, { + error: expect.objectContaining({ message: 'Error 3' }), + retries: 2, + consecutiveFailures: 3, + }); + }); + + it('should fall back to default behavior when onPollingError is not provided', async () => { + const mockTask = vi + .fn() + .mockRejectedValueOnce(new Error('Error 1')) + .mockRejectedValueOnce(new Error('Error 2')) + .mockRejectedValueOnce(new Error('Error 3')); + + const mockCheckStatus = vi.fn(); + + await expect( + asyncifyPolling({ + pollingQuery: mockTask, + checkStatus: mockCheckStatus, + maxConsecutiveFailures: 2, + initialInterval: 1, + }), + ).rejects.toThrow(/consecutive attempts/); + + expect(mockTask).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/packages/model-runtime/src/utils/asyncifyPolling.ts b/packages/model-runtime/src/utils/asyncifyPolling.ts new file mode 100644 index 0000000000..db8cac9680 --- /dev/null +++ b/packages/model-runtime/src/utils/asyncifyPolling.ts @@ -0,0 +1,175 @@ +export interface TaskResult { + data?: T; + error?: any; + status: 'pending' | 'success' | 'failed'; +} + +export interface PollingErrorContext { + consecutiveFailures: number; + error: any; + retries: number; +} + +export interface PollingErrorResult { + error?: any; + isContinuePolling: boolean; // If provided, will replace the original error when thrown +} + +export interface AsyncifyPollingOptions { + // Default 5000ms + backoffMultiplier?: number; + + // Status check function to determine task result + checkStatus: (result: T) => TaskResult; + + // Retry configuration + initialInterval?: number; + // Optional logger + logger?: { + debug?: (...args: any[]) => void; + error?: (...args: any[]) => void; + }; + // Default 1.5 + maxConsecutiveFailures?: number; + // Default 500ms + maxInterval?: number; // Default 3 + maxRetries?: number; // Default Infinity + + // Optional custom error handler for polling query failures + onPollingError?: (context: PollingErrorContext) => PollingErrorResult; + + // The polling function to execute repeatedly + pollingQuery: () => Promise; +} + +/** + * Convert polling pattern to async/await pattern + * + * @param options Polling configuration options + * @returns Promise The data returned when task completes + * @throws Error When task fails or times out + */ +export async function asyncifyPolling(options: AsyncifyPollingOptions): Promise { + const { + pollingQuery, + checkStatus, + initialInterval = 500, + maxInterval = 5000, + backoffMultiplier = 1.5, + maxConsecutiveFailures = 3, + maxRetries = Infinity, + onPollingError, + logger, + } = options; + + let retries = 0; + let consecutiveFailures = 0; + + while (retries < maxRetries) { + let pollingResult: T; + + try { + // Execute polling function + pollingResult = await pollingQuery(); + + // Reset consecutive failures counter on successful execution + consecutiveFailures = 0; + } catch (error) { + // Polling function execution failed (network error, etc.) + consecutiveFailures++; + + logger?.error?.( + `Failed to execute polling function (attempt ${retries + 1}/${maxRetries === Infinity ? '∞' : maxRetries}, consecutive failures: ${consecutiveFailures}/${maxConsecutiveFailures}):`, + error, + ); + + // Handle custom error processing if provided + if (onPollingError) { + const errorResult = onPollingError({ + consecutiveFailures, + error, + retries, + }); + + if (!errorResult.isContinuePolling) { + // Custom error handler decided to stop polling + throw errorResult.error || error; + } + + // Custom error handler decided to continue polling + logger?.debug?.('Custom error handler decided to continue polling'); + } else { + // Default behavior: check if maximum consecutive failures reached + if (consecutiveFailures >= maxConsecutiveFailures) { + throw new Error( + `Failed to execute polling function after ${consecutiveFailures} consecutive attempts: ${error}`, + ); + } + } + + // Wait before retry and continue to next loop iteration + if (retries < maxRetries - 1) { + const currentInterval = Math.min( + initialInterval * Math.pow(backoffMultiplier, retries), + maxInterval, + ); + + logger?.debug?.(`Waiting ${currentInterval}ms before next retry`); + + await new Promise((resolve) => { + setTimeout(resolve, currentInterval); + }); + } + + retries++; + continue; + } + + // Check task status + const statusResult = checkStatus(pollingResult); + + logger?.debug?.(`Task status: ${statusResult.status} (attempt ${retries + 1})`); + + switch (statusResult.status) { + case 'success': { + return statusResult.data as R; + } + + case 'failed': { + // Task logic failed, throw error immediately (not counted as consecutive failure) + throw statusResult.error || new Error('Task failed'); + } + + case 'pending': { + // Continue polling + break; + } + + default: { + // Unknown status, treat as pending + break; + } + } + + // Wait before next retry if not the last attempt + if (retries < maxRetries - 1) { + // Calculate dynamic retry interval with exponential backoff + const currentInterval = Math.min( + initialInterval * Math.pow(backoffMultiplier, retries), + maxInterval, + ); + + logger?.debug?.(`Waiting ${currentInterval}ms before next retry`); + + // Wait for retry interval + await new Promise((resolve) => { + setTimeout(resolve, currentInterval); + }); + } + + retries++; + } + + // Maximum retries reached + throw new Error(`Task timeout after ${maxRetries} attempts`); +} diff --git a/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/index.tsx b/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/index.tsx index ac47cb1e2c..8ce017b40f 100644 --- a/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/index.tsx +++ b/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/index.tsx @@ -46,7 +46,7 @@ const ConfigPanel = memo(() => { const { showDimensionControl } = useDimensionControl(); return ( - + diff --git a/src/config/aiModels/bfl.ts b/src/config/aiModels/bfl.ts new file mode 100644 index 0000000000..8fdb571ee6 --- /dev/null +++ b/src/config/aiModels/bfl.ts @@ -0,0 +1,145 @@ +import { PRESET_ASPECT_RATIOS } from '@/const/image'; +import { ModelParamsSchema } from '@/libs/standard-parameters'; +import { AIImageModelCard } from '@/types/aiModel'; + +// https://docs.bfl.ai/api-reference/tasks/edit-or-create-an-image-with-flux-kontext-pro +// official support 21:9 ~ 9:21 (ratio 0.43 ~ 2.33) +const calculateRatio = (aspectRatio: string): number => { + const [width, height] = aspectRatio.split(':').map(Number); + return width / height; +}; + +const defaultAspectRatios = PRESET_ASPECT_RATIOS.filter((ratio) => { + const value = calculateRatio(ratio); + // BFL API supports ratio range: 21:9 ~ 9:21 (approximately 0.43 ~ 2.33) + // Use a small tolerance for floating point comparison + return value >= 9 / 21 - 0.001 && value <= 21 / 9 + 0.001; +}); + +const fluxKontextSeriesParamsSchema: ModelParamsSchema = { + aspectRatio: { + default: '1:1', + enum: defaultAspectRatios, + }, + imageUrls: { + default: [], + }, + prompt: { default: '' }, + seed: { default: null }, +}; + +const imageModels: AIImageModelCard[] = [ + // https://docs.bfl.ai/api-reference/tasks/edit-or-create-an-image-with-flux-kontext-pro + { + description: '最先进的上下文图像生成和编辑——结合文本和图像以获得精确、连贯的结果。', + displayName: 'FLUX.1 Kontext [pro]', + enabled: true, + id: 'flux-kontext-pro', + parameters: fluxKontextSeriesParamsSchema, + // check: https://bfl.ai/pricing + pricing: { + units: [{ name: 'imageGeneration', rate: 0.04, strategy: 'fixed', unit: 'image' }], + }, + releasedAt: '2025-05-29', + type: 'image', + }, + // https://docs.bfl.ai/api-reference/tasks/edit-or-create-an-image-with-flux-kontext-max + { + description: '最先进的上下文图像生成和编辑——结合文本和图像以获得精确、连贯的结果。', + displayName: 'FLUX.1 Kontext [max]', + enabled: true, + id: 'flux-kontext-max', + parameters: fluxKontextSeriesParamsSchema, + pricing: { + units: [{ name: 'imageGeneration', rate: 0.08, strategy: 'fixed', unit: 'image' }], + }, + releasedAt: '2025-05-29', + type: 'image', + }, + // https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux-11-[pro] + { + description: '升级版专业级AI图像生成模型——提供卓越的图像质量和精确的提示词遵循能力。', + displayName: 'FLUX1.1 [pro] ', + enabled: true, + id: 'flux-pro-1.1', + parameters: { + height: { default: 768, max: 1440, min: 256, step: 32 }, + imageUrl: { default: null }, + prompt: { default: '' }, + seed: { default: null }, + width: { default: 1024, max: 1440, min: 256, step: 32 }, + }, + pricing: { + units: [{ name: 'imageGeneration', rate: 0.06, strategy: 'fixed', unit: 'image' }], + }, + releasedAt: '2024-10-02', + type: 'image', + }, + // https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux-11-[pro]-with-ultra-mode-and-optional-raw-mode + { + description: '超高分辨率AI图像生成——支持4兆像素输出,10秒内生成超清图像。', + displayName: 'FLUX1.1 [pro] Ultra', + enabled: true, + id: 'flux-pro-1.1-ultra', + parameters: { + aspectRatio: { + default: '16:9', + enum: defaultAspectRatios, + }, + imageUrl: { default: null }, + prompt: { default: '' }, + seed: { default: null }, + }, + pricing: { + units: [{ name: 'imageGeneration', rate: 0.06, strategy: 'fixed', unit: 'image' }], + }, + releasedAt: '2024-11-06', + type: 'image', + }, + // https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux1-[pro] + { + description: '顶级商用AI图像生成模型——无与伦比的图像质量和多样化输出表现。', + displayName: 'FLUX.1 [pro]', + enabled: true, + id: 'flux-pro', + parameters: { + cfg: { default: 2.5, max: 5, min: 1.5, step: 0.1 }, + height: { default: 768, max: 1440, min: 256, step: 32 }, + imageUrl: { default: null }, + prompt: { default: '' }, + seed: { default: null }, + steps: { default: 40, max: 50, min: 1 }, + width: { default: 1024, max: 1440, min: 256, step: 32 }, + }, + pricing: { + units: [{ name: 'imageGeneration', rate: 0.025, strategy: 'fixed', unit: 'image' }], + }, + releasedAt: '2024-08-01', + type: 'image', + }, + // https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux1-[dev] + { + description: '开源研发版AI图像生成模型——高效优化,适合非商业用途的创新研究。', + displayName: 'FLUX.1 [dev]', + enabled: true, + id: 'flux-dev', + parameters: { + cfg: { default: 3, max: 5, min: 1.5, step: 0.1 }, + height: { default: 768, max: 1440, min: 256, step: 32 }, + imageUrl: { default: null }, + prompt: { default: '' }, + seed: { default: null }, + steps: { default: 28, max: 50, min: 1 }, + width: { default: 1024, max: 1440, min: 256, step: 32 }, + }, + pricing: { + units: [{ name: 'imageGeneration', rate: 0.025, strategy: 'fixed', unit: 'image' }], + }, + releasedAt: '2024-08-01', + type: 'image', + }, +]; + +export const allModels = [...imageModels]; + +export default allModels; diff --git a/src/config/aiModels/index.ts b/src/config/aiModels/index.ts index 5a850f541c..64a7bfb666 100644 --- a/src/config/aiModels/index.ts +++ b/src/config/aiModels/index.ts @@ -9,6 +9,7 @@ import { default as azure } from './azure'; import { default as azureai } from './azureai'; import { default as baichuan } from './baichuan'; import { default as bedrock } from './bedrock'; +import { default as bfl } from './bfl'; import { default as cloudflare } from './cloudflare'; import { default as cohere } from './cohere'; import { default as deepseek } from './deepseek'; @@ -87,6 +88,7 @@ export const LOBE_DEFAULT_MODEL_LIST = buildDefaultModelList({ azureai, baichuan, bedrock, + bfl, cloudflare, cohere, deepseek, @@ -146,6 +148,7 @@ export { default as azure } from './azure'; export { default as azureai } from './azureai'; export { default as baichuan } from './baichuan'; export { default as bedrock } from './bedrock'; +export { default as bfl } from './bfl'; export { default as cloudflare } from './cloudflare'; export { default as cohere } from './cohere'; export { default as deepseek } from './deepseek'; diff --git a/src/config/llm.ts b/src/config/llm.ts index 50654bdd4d..da6dca0d0b 100644 --- a/src/config/llm.ts +++ b/src/config/llm.ts @@ -166,13 +166,15 @@ export const getLLMConfig = () => { ENABLED_FAL: z.boolean(), FAL_API_KEY: z.string().optional(), + ENABLED_BFL: z.boolean(), + BFL_API_KEY: z.string().optional(), + ENABLED_MODELSCOPE: z.boolean(), MODELSCOPE_API_KEY: z.string().optional(), ENABLED_V0: z.boolean(), V0_API_KEY: z.string().optional(), - ENABLED_AI302: z.boolean(), AI302_API_KEY: z.string().optional(), @@ -342,6 +344,9 @@ export const getLLMConfig = () => { ENABLED_FAL: process.env.ENABLED_FAL !== '0', FAL_API_KEY: process.env.FAL_API_KEY, + ENABLED_BFL: !!process.env.BFL_API_KEY, + BFL_API_KEY: process.env.BFL_API_KEY, + ENABLED_MODELSCOPE: !!process.env.MODELSCOPE_API_KEY, MODELSCOPE_API_KEY: process.env.MODELSCOPE_API_KEY, diff --git a/src/config/modelProviders/bfl.ts b/src/config/modelProviders/bfl.ts new file mode 100644 index 0000000000..2e9f6c19ff --- /dev/null +++ b/src/config/modelProviders/bfl.ts @@ -0,0 +1,21 @@ +import { ModelProviderCard } from '@/types/llm'; + +/** + * @see https://docs.bfl.ai/ + */ +const Bfl: ModelProviderCard = { + chatModels: [], + description: '领先的前沿人工智能研究实验室,构建明日的视觉基础设施。', + enabled: true, + id: 'bfl', + name: 'Black Forest Labs', + settings: { + disableBrowserRequest: true, + showAddNewModel: false, + showChecker: false, + showModelFetcher: false, + }, + url: 'https://bfl.ai/', +}; + +export default Bfl; diff --git a/src/config/modelProviders/index.ts b/src/config/modelProviders/index.ts index 389610b498..96eeca6333 100644 --- a/src/config/modelProviders/index.ts +++ b/src/config/modelProviders/index.ts @@ -9,6 +9,7 @@ import AzureProvider from './azure'; import AzureAIProvider from './azureai'; import BaichuanProvider from './baichuan'; import BedrockProvider from './bedrock'; +import BflProvider from './bfl'; import CloudflareProvider from './cloudflare'; import CohereProvider from './cohere'; import DeepSeekProvider from './deepseek'; @@ -132,6 +133,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [ HuggingFaceProvider, CloudflareProvider, GithubProvider, + BflProvider, NovitaProvider, PPIOProvider, NvidiaProvider, @@ -191,6 +193,7 @@ export { default as AzureProviderCard } from './azure'; export { default as AzureAIProviderCard } from './azureai'; export { default as BaichuanProviderCard } from './baichuan'; export { default as BedrockProviderCard } from './bedrock'; +export { default as BflProviderCard } from './bfl'; export { default as CloudflareProviderCard } from './cloudflare'; export { default as CohereProviderCard } from './cohere'; export { default as DeepSeekProviderCard } from './deepseek'; diff --git a/src/store/image/slices/generationConfig/hooks.ts b/src/store/image/slices/generationConfig/hooks.ts index 4b1e44b849..aae63ca77b 100644 --- a/src/store/image/slices/generationConfig/hooks.ts +++ b/src/store/image/slices/generationConfig/hooks.ts @@ -77,17 +77,13 @@ export function useDimensionControl() { const aspectRatioOptions = useMemo(() => { const modelOptions = paramsSchema?.aspectRatio?.enum || []; - // 合并选项,优先使用预设选项,然后添加模型特有的选项 - const allOptions = [...PRESET_ASPECT_RATIOS]; + // 如果 schema 里面有 aspectRatio 并且不为空,直接使用 schema 里面的选项 + if (modelOptions.length > 0) { + return modelOptions; + } - // 添加模型选项中不在预设中的选项 - modelOptions.forEach((option) => { - if (!allOptions.includes(option)) { - allOptions.push(option); - } - }); - - return allOptions; + // 否则使用预设选项 + return PRESET_ASPECT_RATIOS; }, [paramsSchema]); // 只要不是所有维度相关的控件都不显示,那么这个容器就应该显示