mirror of
https://github.com/lobehub/lobehub.git
synced 2026-03-26 13:19:34 +07:00
✨ feat(provider): add BFL provider support for image generation (#8806)
This commit is contained in:
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
@@ -68,7 +68,7 @@
|
||||
"**/src/config/modelProviders/*.ts": "${filename} • provider",
|
||||
"**/src/config/aiModels/*.ts": "${filename} • model",
|
||||
"**/src/config/paramsSchemas/*/*.json": "${dirname(1)}/${filename} • params",
|
||||
"**/src/libs/model-runtime/*/index.ts": "${dirname} • runtime",
|
||||
"**/packages/model-runtime/src/*/index.ts": "${dirname} • runtime",
|
||||
|
||||
"**/src/server/services/*/index.ts": "${dirname} • server/service",
|
||||
"**/src/server/routers/lambda/*.ts": "${filename} • lambda",
|
||||
|
||||
@@ -151,7 +151,7 @@
|
||||
"@lobehub/charts": "^2.0.0",
|
||||
"@lobehub/chat-plugin-sdk": "^1.32.4",
|
||||
"@lobehub/chat-plugins-gateway": "^1.9.0",
|
||||
"@lobehub/icons": "^2.25.0",
|
||||
"@lobehub/icons": "^2.27.1",
|
||||
"@lobehub/market-sdk": "^0.22.7",
|
||||
"@lobehub/tts": "^2.0.1",
|
||||
"@lobehub/ui": "^2.8.3",
|
||||
|
||||
@@ -3,4 +3,12 @@
|
||||
*/
|
||||
export const DEFAULT_ASPECT_RATIO = '1:1';
|
||||
|
||||
export const PRESET_ASPECT_RATIOS = [DEFAULT_ASPECT_RATIO, '16:9', '9:16', '4:3', '3:4'];
|
||||
export const PRESET_ASPECT_RATIOS = [
|
||||
DEFAULT_ASPECT_RATIO, // '1:1' - 正方形,最常用
|
||||
'16:9', // 现代显示器/电视/视频标准
|
||||
'9:16', // 手机竖屏/短视频
|
||||
'4:3', // 传统显示器/照片
|
||||
'3:4', // 传统竖屏照片
|
||||
'3:2', // 经典照片比例横屏
|
||||
'2:3', // 经典照片比例竖屏
|
||||
];
|
||||
|
||||
846
packages/model-runtime/src/bfl/createImage.test.ts
Normal file
846
packages/model-runtime/src/bfl/createImage.test.ts
Normal file
@@ -0,0 +1,846 @@
|
||||
// @vitest-environment node
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { CreateImagePayload } from '@/libs/model-runtime/types/image';
|
||||
|
||||
import { createBflImage } from './createImage';
|
||||
import { BflStatusResponse } from './types';
|
||||
|
||||
// Mock external dependencies
|
||||
vi.mock('@/utils/imageToBase64', () => ({
|
||||
imageUrlToBase64: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/uriParser', () => ({
|
||||
parseDataUri: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/asyncifyPolling', () => ({
|
||||
asyncifyPolling: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock fetch
|
||||
global.fetch = vi.fn();
|
||||
const mockFetch = vi.mocked(fetch);
|
||||
|
||||
// Mock the console.error to avoid polluting test output
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
const mockOptions = {
|
||||
apiKey: 'test-api-key',
|
||||
provider: 'bfl' as const,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('createBflImage', () => {
|
||||
describe('Parameter mapping and defaults', () => {
|
||||
it('should map standard parameters to BFL-specific parameters', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'A beautiful landscape',
|
||||
aspectRatio: '16:9',
|
||||
cfg: 7.5,
|
||||
steps: 20,
|
||||
seed: 12345,
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://api.bfl.ai/v1/flux-dev',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-key': 'test-api-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
prompt: 'A beautiful landscape',
|
||||
aspect_ratio: '16:9',
|
||||
guidance: 7.5,
|
||||
steps: 20,
|
||||
seed: 12345,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should add raw: true for ultra models', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-pro-1.1-ultra',
|
||||
params: {
|
||||
prompt: 'Ultra quality image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://api.bfl.ai/v1/flux-pro-1.1-ultra',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
raw: true,
|
||||
prompt: 'Ultra quality image',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should filter out undefined values', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
cfg: undefined,
|
||||
seed: 12345,
|
||||
steps: undefined,
|
||||
} as any,
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const requestBody = JSON.parse(callArgs?.body as string);
|
||||
|
||||
expect(requestBody).toEqual({
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
prompt: 'Test image',
|
||||
seed: 12345,
|
||||
});
|
||||
|
||||
expect(requestBody).not.toHaveProperty('guidance');
|
||||
expect(requestBody).not.toHaveProperty('steps');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Image URL handling', () => {
|
||||
it('should convert single imageUrl to image_prompt base64', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../utils/uriParser');
|
||||
const { imageUrlToBase64 } = await import('@/utils/imageToBase64');
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
const mockImageUrlToBase64 = vi.mocked(imageUrlToBase64);
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockParseDataUri.mockReturnValue({ type: 'url', base64: null, mimeType: null });
|
||||
mockImageUrlToBase64.mockResolvedValue({
|
||||
base64: 'base64EncodedImage',
|
||||
mimeType: 'image/jpeg',
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-pro-1.1',
|
||||
params: {
|
||||
prompt: 'Transform this image',
|
||||
imageUrl: 'https://example.com/input.jpg',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
expect(mockParseDataUri).toHaveBeenCalledWith('https://example.com/input.jpg');
|
||||
expect(mockImageUrlToBase64).toHaveBeenCalledWith('https://example.com/input.jpg');
|
||||
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const requestBody = JSON.parse(callArgs?.body as string);
|
||||
|
||||
expect(requestBody).toEqual({
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
prompt: 'Transform this image',
|
||||
image_prompt: 'base64EncodedImage',
|
||||
});
|
||||
|
||||
expect(requestBody).not.toHaveProperty('imageUrl');
|
||||
});
|
||||
|
||||
it('should handle base64 imageUrl directly', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../utils/uriParser');
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockParseDataUri.mockReturnValue({
|
||||
type: 'base64',
|
||||
base64: '/9j/4AAQSkZJRgABAQEAYABgAAD',
|
||||
mimeType: 'image/jpeg',
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const base64Image = 'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD';
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-pro-1.1',
|
||||
params: {
|
||||
prompt: 'Transform this image',
|
||||
imageUrl: base64Image,
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const requestBody = JSON.parse(callArgs?.body as string);
|
||||
|
||||
expect(requestBody.image_prompt).toBe('/9j/4AAQSkZJRgABAQEAYABgAAD');
|
||||
});
|
||||
|
||||
it('should convert multiple imageUrls for Kontext models', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../utils/uriParser');
|
||||
const { imageUrlToBase64 } = await import('@/utils/imageToBase64');
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
const mockImageUrlToBase64 = vi.mocked(imageUrlToBase64);
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockParseDataUri.mockReturnValue({ type: 'url', base64: null, mimeType: null });
|
||||
mockImageUrlToBase64
|
||||
.mockResolvedValueOnce({ base64: 'base64image1', mimeType: 'image/jpeg' })
|
||||
.mockResolvedValueOnce({ base64: 'base64image2', mimeType: 'image/jpeg' })
|
||||
.mockResolvedValueOnce({ base64: 'base64image3', mimeType: 'image/jpeg' });
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-kontext-pro',
|
||||
params: {
|
||||
prompt: 'Create variation of these images',
|
||||
imageUrls: [
|
||||
'https://example.com/input1.jpg',
|
||||
'https://example.com/input2.jpg',
|
||||
'https://example.com/input3.jpg',
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const requestBody = JSON.parse(callArgs?.body as string);
|
||||
|
||||
expect(requestBody).toEqual({
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
prompt: 'Create variation of these images',
|
||||
input_image: 'base64image1',
|
||||
input_image_2: 'base64image2',
|
||||
input_image_3: 'base64image3',
|
||||
});
|
||||
|
||||
expect(requestBody).not.toHaveProperty('imageUrls');
|
||||
});
|
||||
|
||||
it('should limit imageUrls to maximum 4 images', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../utils/uriParser');
|
||||
const { imageUrlToBase64 } = await import('@/utils/imageToBase64');
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
const mockImageUrlToBase64 = vi.mocked(imageUrlToBase64);
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockParseDataUri.mockReturnValue({ type: 'url', base64: null, mimeType: null });
|
||||
mockImageUrlToBase64
|
||||
.mockResolvedValueOnce({ base64: 'base64image1', mimeType: 'image/jpeg' })
|
||||
.mockResolvedValueOnce({ base64: 'base64image2', mimeType: 'image/jpeg' })
|
||||
.mockResolvedValueOnce({ base64: 'base64image3', mimeType: 'image/jpeg' })
|
||||
.mockResolvedValueOnce({ base64: 'base64image4', mimeType: 'image/jpeg' });
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-kontext-max',
|
||||
params: {
|
||||
prompt: 'Create variation of these images',
|
||||
imageUrls: [
|
||||
'https://example.com/input1.jpg',
|
||||
'https://example.com/input2.jpg',
|
||||
'https://example.com/input3.jpg',
|
||||
'https://example.com/input4.jpg',
|
||||
'https://example.com/input5.jpg', // This should be ignored
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
expect(mockImageUrlToBase64).toHaveBeenCalledTimes(4);
|
||||
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const requestBody = JSON.parse(callArgs?.body as string);
|
||||
|
||||
expect(requestBody).toEqual({
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
prompt: 'Create variation of these images',
|
||||
input_image: 'base64image1',
|
||||
input_image_2: 'base64image2',
|
||||
input_image_3: 'base64image3',
|
||||
input_image_4: 'base64image4',
|
||||
});
|
||||
|
||||
expect(requestBody).not.toHaveProperty('input_image_5');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model endpoint mapping', () => {
|
||||
it('should map models to correct endpoints', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const testCases = [
|
||||
{ model: 'flux-dev', endpoint: '/v1/flux-dev' },
|
||||
{ model: 'flux-pro', endpoint: '/v1/flux-pro' },
|
||||
{ model: 'flux-pro-1.1', endpoint: '/v1/flux-pro-1.1' },
|
||||
{ model: 'flux-pro-1.1-ultra', endpoint: '/v1/flux-pro-1.1-ultra' },
|
||||
{ model: 'flux-kontext-pro', endpoint: '/v1/flux-kontext-pro' },
|
||||
{ model: 'flux-kontext-max', endpoint: '/v1/flux-kontext-max' },
|
||||
];
|
||||
|
||||
// Act & Assert
|
||||
for (const { model, endpoint } of testCases) {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model,
|
||||
params: {
|
||||
prompt: `Test image for ${model}`,
|
||||
},
|
||||
};
|
||||
|
||||
await createBflImage(payload, mockOptions);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(`https://api.bfl.ai${endpoint}`, expect.any(Object));
|
||||
}
|
||||
});
|
||||
|
||||
it('should throw error for unsupported model', async () => {
|
||||
// Arrange
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'unsupported-model',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.objectContaining({
|
||||
message: 'Unsupported BFL model: unsupported-model',
|
||||
}),
|
||||
errorType: 'ModelNotFound',
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
|
||||
it('should use custom baseURL when provided', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://custom-api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockResolvedValue({
|
||||
imageUrl: 'https://example.com/result.jpg',
|
||||
});
|
||||
|
||||
const customOptions = {
|
||||
...mockOptions,
|
||||
baseURL: 'https://custom-api.bfl.ai',
|
||||
};
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test with custom URL',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await createBflImage(payload, customOptions);
|
||||
|
||||
// Assert
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://custom-api.bfl.ai/v1/flux-dev',
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Status handling', () => {
|
||||
it('should return success when status is Ready with result', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
// Mock the asyncifyPolling to call checkStatus with Ready status
|
||||
mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => {
|
||||
const result = checkStatus({
|
||||
id: 'task-123',
|
||||
status: BflStatusResponse.Ready,
|
||||
result: {
|
||||
sample: 'https://example.com/generated-image.jpg',
|
||||
},
|
||||
});
|
||||
|
||||
if (result.status === 'success') {
|
||||
return result.data;
|
||||
}
|
||||
throw result.error;
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test successful generation',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual({
|
||||
imageUrl: 'https://example.com/generated-image.jpg',
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error when status is Ready but no result', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => {
|
||||
const result = checkStatus({
|
||||
id: 'task-123',
|
||||
status: BflStatusResponse.Ready,
|
||||
result: null,
|
||||
});
|
||||
|
||||
if (result.status === 'success') {
|
||||
return result.data;
|
||||
}
|
||||
throw result.error;
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test no result error',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.any(Object),
|
||||
errorType: 'ProviderBizError',
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle error statuses', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const errorStatuses = [
|
||||
BflStatusResponse.Error,
|
||||
BflStatusResponse.ContentModerated,
|
||||
BflStatusResponse.RequestModerated,
|
||||
];
|
||||
|
||||
for (const status of errorStatuses) {
|
||||
mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => {
|
||||
const result = checkStatus({
|
||||
id: 'task-123',
|
||||
status,
|
||||
details: { error: 'Test error details' },
|
||||
});
|
||||
|
||||
if (result.status === 'success') {
|
||||
return result.data;
|
||||
}
|
||||
throw result.error;
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: `Test ${status} error`,
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.any(Object),
|
||||
errorType: 'ProviderBizError',
|
||||
provider: 'bfl',
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle TaskNotFound status', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => {
|
||||
const result = checkStatus({
|
||||
id: 'task-123',
|
||||
status: BflStatusResponse.TaskNotFound,
|
||||
});
|
||||
|
||||
if (result.status === 'success') {
|
||||
return result.data;
|
||||
}
|
||||
throw result.error;
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test task not found',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.any(Object),
|
||||
errorType: 'ProviderBizError',
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
|
||||
it('should continue polling for Pending status', async () => {
|
||||
// Arrange
|
||||
const { asyncifyPolling } = await import('../utils/asyncifyPolling');
|
||||
const mockAsyncifyPolling = vi.mocked(asyncifyPolling);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
id: 'task-123',
|
||||
polling_url: 'https://api.bfl.ai/v1/get_result?id=task-123',
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
mockAsyncifyPolling.mockImplementation(async ({ checkStatus }) => {
|
||||
// First call - Pending status
|
||||
const pendingResult = checkStatus({
|
||||
id: 'task-123',
|
||||
status: BflStatusResponse.Pending,
|
||||
});
|
||||
|
||||
expect(pendingResult.status).toBe('pending');
|
||||
|
||||
// Simulate successful completion
|
||||
const successResult = checkStatus({
|
||||
id: 'task-123',
|
||||
status: BflStatusResponse.Ready,
|
||||
result: {
|
||||
sample: 'https://example.com/generated-image.jpg',
|
||||
},
|
||||
});
|
||||
|
||||
return successResult.data;
|
||||
});
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test pending status',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createBflImage(payload, mockOptions);
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual({
|
||||
imageUrl: 'https://example.com/generated-image.jpg',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should handle fetch errors during task submission', async () => {
|
||||
// Arrange
|
||||
mockFetch.mockRejectedValue(new Error('Network error'));
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test network error',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toThrow();
|
||||
});
|
||||
|
||||
it('should handle HTTP error responses', async () => {
|
||||
// Arrange
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
statusText: 'Bad Request',
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
detail: [{ msg: 'Invalid prompt' }],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test HTTP error',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.any(Object),
|
||||
errorType: 'ProviderBizError',
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle HTTP error responses without detail', async () => {
|
||||
// Arrange
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
statusText: 'Internal Server Error',
|
||||
json: () => Promise.resolve({}),
|
||||
} as Response);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test HTTP error without detail',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.any(Object),
|
||||
errorType: 'ProviderBizError',
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle non-JSON error responses', async () => {
|
||||
// Arrange
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
statusText: 'Internal Server Error',
|
||||
json: () => Promise.reject(new Error('Invalid JSON')),
|
||||
} as Response);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test non-JSON error',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(createBflImage(payload, mockOptions)).rejects.toMatchObject({
|
||||
error: expect.any(Object),
|
||||
errorType: 'ProviderBizError',
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
279
packages/model-runtime/src/bfl/createImage.ts
Normal file
279
packages/model-runtime/src/bfl/createImage.ts
Normal file
@@ -0,0 +1,279 @@
|
||||
import createDebug from 'debug';
|
||||
|
||||
import { RuntimeImageGenParamsValue } from '@/libs/standard-parameters/index';
|
||||
import { imageUrlToBase64 } from '@/utils/imageToBase64';
|
||||
|
||||
import { AgentRuntimeErrorType } from '../error';
|
||||
import { CreateImagePayload, CreateImageResponse } from '../types/image';
|
||||
import { type TaskResult, asyncifyPolling } from '../utils/asyncifyPolling';
|
||||
import { AgentRuntimeError } from '../utils/createError';
|
||||
import { parseDataUri } from '../utils/uriParser';
|
||||
import {
|
||||
BFL_ENDPOINTS,
|
||||
BflAsyncResponse,
|
||||
BflModelId,
|
||||
BflRequest,
|
||||
BflResultResponse,
|
||||
BflStatusResponse,
|
||||
} from './types';
|
||||
|
||||
const log = createDebug('lobe-image:bfl');
|
||||
|
||||
const BASE_URL = 'https://api.bfl.ai';
|
||||
|
||||
interface BflCreateImageOptions {
|
||||
apiKey: string;
|
||||
baseURL?: string;
|
||||
provider: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert image URL to base64 format required by BFL API
|
||||
*/
|
||||
async function convertImageToBase64(imageUrl: string): Promise<string> {
|
||||
try {
|
||||
const { type } = parseDataUri(imageUrl);
|
||||
|
||||
if (type === 'base64') {
|
||||
// Already in base64 format, extract the base64 part
|
||||
const base64Match = imageUrl.match(/^data:[^;]+;base64,(.+)$/);
|
||||
if (base64Match) {
|
||||
return base64Match[1];
|
||||
}
|
||||
throw new Error('Invalid base64 format');
|
||||
}
|
||||
|
||||
if (type === 'url') {
|
||||
// Convert URL to base64
|
||||
const { base64 } = await imageUrlToBase64(imageUrl);
|
||||
return base64;
|
||||
}
|
||||
|
||||
throw new Error(`Invalid image URL format: ${imageUrl}`);
|
||||
} catch (error) {
|
||||
log('Error converting image to base64: %O', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build request payload for different BFL models
|
||||
*/
|
||||
async function buildRequestPayload(
|
||||
model: BflModelId,
|
||||
params: CreateImagePayload['params'],
|
||||
): Promise<BflRequest> {
|
||||
log('Building request payload for model: %s', model);
|
||||
|
||||
// Define parameter mapping (BFL API specific)
|
||||
const paramsMap = new Map<RuntimeImageGenParamsValue, string>([
|
||||
['aspectRatio', 'aspect_ratio'],
|
||||
['cfg', 'guidance'],
|
||||
]);
|
||||
|
||||
// Fixed parameters for all BFL models
|
||||
const defaultPayload: Record<string, unknown> = {
|
||||
output_format: 'png',
|
||||
safety_tolerance: 6,
|
||||
...(model.includes('ultra') && { raw: true }),
|
||||
};
|
||||
|
||||
// Map user parameters, filtering out undefined values
|
||||
const userPayload: Record<string, unknown> = Object.fromEntries(
|
||||
(Object.entries(params) as [keyof typeof params, any][])
|
||||
.filter(([, value]) => value !== undefined)
|
||||
.map(([key, value]) => [paramsMap.get(key) ?? key, value]),
|
||||
);
|
||||
|
||||
// Handle multiple input images (imageUrls) for Kontext models
|
||||
if (params.imageUrls && params.imageUrls.length > 0) {
|
||||
for (let i = 0; i < Math.min(params.imageUrls.length, 4); i++) {
|
||||
const fieldName = i === 0 ? 'input_image' : `input_image_${i + 1}`;
|
||||
userPayload[fieldName] = await convertImageToBase64(params.imageUrls[i]);
|
||||
}
|
||||
// Remove the original imageUrls field as it's now mapped to input_image_*
|
||||
delete userPayload.imageUrls;
|
||||
}
|
||||
|
||||
// Handle single image input (imageUrl)
|
||||
if (params.imageUrl) {
|
||||
userPayload.image_prompt = await convertImageToBase64(params.imageUrl);
|
||||
// Remove the original imageUrl field as it's now mapped to image_prompt
|
||||
delete userPayload.imageUrl;
|
||||
}
|
||||
|
||||
// Combine default and user payload
|
||||
const payload = {
|
||||
...defaultPayload,
|
||||
...userPayload,
|
||||
};
|
||||
|
||||
return payload as BflRequest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Submit image generation task to BFL API
|
||||
*/
|
||||
async function submitTask(
|
||||
model: BflModelId,
|
||||
payload: BflRequest,
|
||||
options: BflCreateImageOptions,
|
||||
): Promise<BflAsyncResponse> {
|
||||
const endpoint = BFL_ENDPOINTS[model];
|
||||
const url = `${options.baseURL || BASE_URL}${endpoint}`;
|
||||
|
||||
log('Submitting task to: %s', url);
|
||||
|
||||
const response = await fetch(url, {
|
||||
body: JSON.stringify(payload),
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-key': options.apiKey,
|
||||
},
|
||||
method: 'POST',
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorData;
|
||||
try {
|
||||
errorData = await response.json();
|
||||
} catch {
|
||||
// Failed to parse JSON error response
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`BFL API error (${response.status}): ${errorData?.detail?.[0]?.msg || response.statusText}`,
|
||||
);
|
||||
}
|
||||
|
||||
const data: BflAsyncResponse = await response.json();
|
||||
log('Task submitted successfully with ID: %s', data.id);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Query task status using BFL API
|
||||
*/
|
||||
async function queryTaskStatus(
|
||||
pollingUrl: string,
|
||||
options: BflCreateImageOptions,
|
||||
): Promise<BflResultResponse> {
|
||||
log('Querying task status using polling URL: %s', pollingUrl);
|
||||
|
||||
const response = await fetch(pollingUrl, {
|
||||
headers: {
|
||||
'accept': 'application/json',
|
||||
'x-key': options.apiKey,
|
||||
},
|
||||
method: 'GET',
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorData;
|
||||
try {
|
||||
errorData = await response.json();
|
||||
} catch {
|
||||
// Failed to parse JSON error response
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Failed to query task status (${response.status}): ${errorData?.detail?.[0]?.msg || response.statusText}`,
|
||||
);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create image using BFL API with async task polling
|
||||
*/
|
||||
export async function createBflImage(
|
||||
payload: CreateImagePayload,
|
||||
options: BflCreateImageOptions,
|
||||
): Promise<CreateImageResponse> {
|
||||
const { model, params } = payload;
|
||||
|
||||
if (!BFL_ENDPOINTS[model as BflModelId]) {
|
||||
throw AgentRuntimeError.createImage({
|
||||
error: new Error(`Unsupported BFL model: ${model}`),
|
||||
errorType: AgentRuntimeErrorType.ModelNotFound,
|
||||
provider: options.provider,
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. Build request payload
|
||||
const requestPayload = await buildRequestPayload(model as BflModelId, params);
|
||||
|
||||
// 2. Submit image generation task
|
||||
const taskResponse = await submitTask(model as BflModelId, requestPayload, options);
|
||||
|
||||
// 3. Poll task status until completion using asyncifyPolling
|
||||
return await asyncifyPolling<BflResultResponse, CreateImageResponse>({
|
||||
checkStatus: (taskStatus: BflResultResponse): TaskResult<CreateImageResponse> => {
|
||||
log('Task %s status: %s', taskResponse.id, taskStatus.status);
|
||||
|
||||
switch (taskStatus.status) {
|
||||
case BflStatusResponse.Ready: {
|
||||
if (!taskStatus.result?.sample) {
|
||||
return {
|
||||
error: new Error('Task succeeded but no image generated'),
|
||||
status: 'failed',
|
||||
};
|
||||
}
|
||||
|
||||
const imageUrl = taskStatus.result.sample;
|
||||
log('Image generated successfully: %s', imageUrl);
|
||||
|
||||
return {
|
||||
data: { imageUrl },
|
||||
status: 'success',
|
||||
};
|
||||
}
|
||||
case BflStatusResponse.Error:
|
||||
case BflStatusResponse.ContentModerated:
|
||||
case BflStatusResponse.RequestModerated: {
|
||||
// Extract error details if available, otherwise use status
|
||||
let errorMessage = `Image generation failed with status: ${taskStatus.status}`;
|
||||
|
||||
// Check for additional error details in various possible fields
|
||||
if (taskStatus.details && typeof taskStatus.details === 'object') {
|
||||
errorMessage += ` - Details: ${JSON.stringify(taskStatus.details)}`;
|
||||
} else if (taskStatus.result && typeof taskStatus.result === 'object') {
|
||||
errorMessage += ` - Result: ${JSON.stringify(taskStatus.result)}`;
|
||||
}
|
||||
|
||||
return {
|
||||
error: new Error(errorMessage),
|
||||
status: 'failed',
|
||||
};
|
||||
}
|
||||
case BflStatusResponse.TaskNotFound: {
|
||||
return {
|
||||
error: new Error('Task not found - may have expired'),
|
||||
status: 'failed',
|
||||
};
|
||||
}
|
||||
default: {
|
||||
// Continue polling for Pending status or other unknown statuses
|
||||
return { status: 'pending' };
|
||||
}
|
||||
}
|
||||
},
|
||||
logger: {
|
||||
debug: (message: any, ...args: any[]) => log(message, ...args),
|
||||
error: (message: any, ...args: any[]) => log(message, ...args),
|
||||
},
|
||||
pollingQuery: () => queryTaskStatus(taskResponse.polling_url, options),
|
||||
});
|
||||
} catch (error) {
|
||||
log('Error in createBflImage: %O', error);
|
||||
|
||||
throw AgentRuntimeError.createImage({
|
||||
error: error as any,
|
||||
errorType: 'ProviderBizError',
|
||||
provider: options.provider,
|
||||
});
|
||||
}
|
||||
}
|
||||
269
packages/model-runtime/src/bfl/index.test.ts
Normal file
269
packages/model-runtime/src/bfl/index.test.ts
Normal file
@@ -0,0 +1,269 @@
|
||||
// @vitest-environment node
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { CreateImagePayload } from '@/libs/model-runtime/types/image';
|
||||
|
||||
import { LobeBflAI } from './index';
|
||||
|
||||
// Mock the createBflImage function
|
||||
vi.mock('./createImage', () => ({
|
||||
createBflImage: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the console.error to avoid polluting test output
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
const bizErrorType = 'ProviderBizError';
|
||||
const invalidErrorType = 'InvalidProviderAPIKey';
|
||||
|
||||
let instance: LobeBflAI;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
instance = new LobeBflAI({ apiKey: 'test-api-key' });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('LobeBflAI', () => {
|
||||
describe('init', () => {
|
||||
it('should correctly initialize with an API key', () => {
|
||||
const instance = new LobeBflAI({ apiKey: 'test_api_key' });
|
||||
expect(instance).toBeInstanceOf(LobeBflAI);
|
||||
});
|
||||
|
||||
it('should initialize with custom baseURL', () => {
|
||||
const customBaseURL = 'https://custom-api.bfl.ai';
|
||||
const instance = new LobeBflAI({
|
||||
apiKey: 'test_api_key',
|
||||
baseURL: customBaseURL,
|
||||
});
|
||||
expect(instance).toBeInstanceOf(LobeBflAI);
|
||||
});
|
||||
|
||||
it('should throw InvalidProviderAPIKey if no apiKey is provided', () => {
|
||||
expect(() => {
|
||||
new LobeBflAI({});
|
||||
}).toThrow();
|
||||
});
|
||||
|
||||
it('should throw InvalidProviderAPIKey if apiKey is undefined', () => {
|
||||
expect(() => {
|
||||
new LobeBflAI({ apiKey: undefined });
|
||||
}).toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('createImage', () => {
|
||||
let mockCreateBflImage: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const { createBflImage } = await import('./createImage');
|
||||
mockCreateBflImage = vi.mocked(createBflImage);
|
||||
});
|
||||
|
||||
it('should create image successfully with basic parameters', async () => {
|
||||
// Arrange
|
||||
const mockImageResponse = {
|
||||
imageUrl: 'https://example.com/generated-image.jpg',
|
||||
};
|
||||
mockCreateBflImage.mockResolvedValue(mockImageResponse);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'A beautiful landscape with mountains',
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await instance.createImage(payload);
|
||||
|
||||
// Assert
|
||||
expect(mockCreateBflImage).toHaveBeenCalledWith(payload, {
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: undefined,
|
||||
provider: 'bfl',
|
||||
});
|
||||
expect(result).toEqual(mockImageResponse);
|
||||
});
|
||||
|
||||
it('should pass custom baseURL to createBflImage', async () => {
|
||||
// Arrange
|
||||
const customBaseURL = 'https://custom-api.bfl.ai';
|
||||
const customInstance = new LobeBflAI({
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: customBaseURL,
|
||||
});
|
||||
|
||||
const mockImageResponse = {
|
||||
imageUrl: 'https://example.com/generated-image.jpg',
|
||||
};
|
||||
mockCreateBflImage.mockResolvedValue(mockImageResponse);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-pro',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
await customInstance.createImage(payload);
|
||||
|
||||
// Assert
|
||||
expect(mockCreateBflImage).toHaveBeenCalledWith(payload, {
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: customBaseURL,
|
||||
provider: 'bfl',
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should throw InvalidProviderAPIKey on 401 error', async () => {
|
||||
// Arrange
|
||||
const apiError = new Error('Unauthorized') as Error & { status: number };
|
||||
apiError.status = 401;
|
||||
mockCreateBflImage.mockRejectedValue(apiError);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(instance.createImage(payload)).rejects.toEqual({
|
||||
error: { error: apiError },
|
||||
errorType: invalidErrorType,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProviderBizError on other errors', async () => {
|
||||
// Arrange
|
||||
const apiError = new Error('Some other error');
|
||||
mockCreateBflImage.mockRejectedValue(apiError);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(instance.createImage(payload)).rejects.toEqual({
|
||||
error: { error: apiError },
|
||||
errorType: bizErrorType,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProviderBizError on non-401 status errors', async () => {
|
||||
// Arrange
|
||||
const apiError = new Error('Server error') as Error & { status: number };
|
||||
apiError.status = 500;
|
||||
mockCreateBflImage.mockRejectedValue(apiError);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(instance.createImage(payload)).rejects.toEqual({
|
||||
error: { error: apiError },
|
||||
errorType: bizErrorType,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProviderBizError on errors without status property', async () => {
|
||||
// Arrange
|
||||
const apiError = new Error('Network error');
|
||||
mockCreateBflImage.mockRejectedValue(apiError);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-pro-1.1',
|
||||
params: {
|
||||
prompt: 'Test image',
|
||||
},
|
||||
};
|
||||
|
||||
// Act & Assert
|
||||
await expect(instance.createImage(payload)).rejects.toEqual({
|
||||
error: { error: apiError },
|
||||
errorType: bizErrorType,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle different model types', async () => {
|
||||
// Arrange
|
||||
const mockImageResponse = {
|
||||
imageUrl: 'https://example.com/generated-image.jpg',
|
||||
};
|
||||
mockCreateBflImage.mockResolvedValue(mockImageResponse);
|
||||
|
||||
const models = [
|
||||
'flux-dev',
|
||||
'flux-pro',
|
||||
'flux-pro-1.1',
|
||||
'flux-pro-1.1-ultra',
|
||||
'flux-kontext-pro',
|
||||
'flux-kontext-max',
|
||||
];
|
||||
|
||||
// Act & Assert
|
||||
for (const model of models) {
|
||||
const payload: CreateImagePayload = {
|
||||
model,
|
||||
params: {
|
||||
prompt: `Test image for ${model}`,
|
||||
},
|
||||
};
|
||||
|
||||
await instance.createImage(payload);
|
||||
|
||||
expect(mockCreateBflImage).toHaveBeenCalledWith(payload, {
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: undefined,
|
||||
provider: 'bfl',
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle empty params object', async () => {
|
||||
// Arrange
|
||||
const mockImageResponse = {
|
||||
imageUrl: 'https://example.com/generated-image.jpg',
|
||||
};
|
||||
mockCreateBflImage.mockResolvedValue(mockImageResponse);
|
||||
|
||||
const payload: CreateImagePayload = {
|
||||
model: 'flux-dev',
|
||||
params: {
|
||||
prompt: 'Empty params test',
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await instance.createImage(payload);
|
||||
|
||||
// Assert
|
||||
expect(mockCreateBflImage).toHaveBeenCalledWith(payload, {
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: undefined,
|
||||
provider: 'bfl',
|
||||
});
|
||||
expect(result).toEqual(mockImageResponse);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
49
packages/model-runtime/src/bfl/index.ts
Normal file
49
packages/model-runtime/src/bfl/index.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import createDebug from 'debug';
|
||||
import { ClientOptions } from 'openai';
|
||||
|
||||
import { LobeRuntimeAI } from '../BaseAI';
|
||||
import { AgentRuntimeErrorType } from '../error';
|
||||
import { CreateImagePayload, CreateImageResponse } from '../types/image';
|
||||
import { AgentRuntimeError } from '../utils/createError';
|
||||
import { createBflImage } from './createImage';
|
||||
|
||||
const log = createDebug('lobe-image:bfl');
|
||||
|
||||
export class LobeBflAI implements LobeRuntimeAI {
|
||||
private apiKey: string;
|
||||
baseURL?: string;
|
||||
|
||||
constructor({ apiKey, baseURL }: ClientOptions = {}) {
|
||||
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
|
||||
|
||||
this.apiKey = apiKey;
|
||||
this.baseURL = baseURL || undefined;
|
||||
|
||||
log('BFL AI initialized');
|
||||
}
|
||||
|
||||
async createImage(payload: CreateImagePayload): Promise<CreateImageResponse> {
|
||||
const { model, params } = payload;
|
||||
log('Creating image with model: %s and params: %O', model, params);
|
||||
|
||||
try {
|
||||
return await createBflImage(payload, {
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.baseURL,
|
||||
provider: 'bfl',
|
||||
});
|
||||
} catch (error) {
|
||||
log('Error in createImage: %O', error);
|
||||
|
||||
// Check for authentication errors based on HTTP status or error properties
|
||||
if (error instanceof Error && 'status' in error && (error as any).status === 401) {
|
||||
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey, {
|
||||
error,
|
||||
});
|
||||
}
|
||||
|
||||
// Wrap other errors
|
||||
throw AgentRuntimeError.createError(AgentRuntimeErrorType.ProviderBizError, { error });
|
||||
}
|
||||
}
|
||||
}
|
||||
113
packages/model-runtime/src/bfl/types.ts
Normal file
113
packages/model-runtime/src/bfl/types.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
// BFL API Types
|
||||
|
||||
export enum BflStatusResponse {
|
||||
ContentModerated = 'Content Moderated',
|
||||
Error = 'Error',
|
||||
Pending = 'Pending',
|
||||
Ready = 'Ready',
|
||||
RequestModerated = 'Request Moderated',
|
||||
TaskNotFound = 'Task not found',
|
||||
}
|
||||
|
||||
export interface BflAsyncResponse {
|
||||
id: string;
|
||||
polling_url: string;
|
||||
}
|
||||
|
||||
export interface BflAsyncWebhookResponse {
|
||||
id: string;
|
||||
status: string;
|
||||
webhook_url: string;
|
||||
}
|
||||
|
||||
export interface BflResultResponse {
|
||||
details?: Record<string, any> | null;
|
||||
id: string;
|
||||
preview?: Record<string, any> | null;
|
||||
progress?: number | null;
|
||||
result?: any;
|
||||
status: BflStatusResponse;
|
||||
}
|
||||
|
||||
// Kontext series request (flux-kontext-pro, flux-kontext-max)
|
||||
export interface BflFluxKontextRequest {
|
||||
aspect_ratio?: string | null;
|
||||
input_image?: string | null;
|
||||
input_image_2?: string | null;
|
||||
input_image_3?: string | null;
|
||||
input_image_4?: string | null;
|
||||
output_format?: 'jpeg' | 'png' | null;
|
||||
prompt: string;
|
||||
prompt_upsampling?: boolean;
|
||||
safety_tolerance?: number;
|
||||
seed?: number | null;
|
||||
webhook_secret?: string | null;
|
||||
webhook_url?: string | null;
|
||||
}
|
||||
|
||||
// FLUX 1.1 Pro request
|
||||
export interface BflFluxPro11Request {
|
||||
height?: number;
|
||||
image_prompt?: string | null;
|
||||
output_format?: 'jpeg' | 'png' | null;
|
||||
prompt?: string | null;
|
||||
prompt_upsampling?: boolean;
|
||||
safety_tolerance?: number;
|
||||
seed?: number | null;
|
||||
webhook_secret?: string | null;
|
||||
webhook_url?: string | null;
|
||||
width?: number;
|
||||
}
|
||||
|
||||
// FLUX 1.1 Pro Ultra request
|
||||
export interface BflFluxPro11UltraRequest {
|
||||
aspect_ratio?: string;
|
||||
prompt: string;
|
||||
raw?: boolean;
|
||||
safety_tolerance?: number;
|
||||
seed?: number | null;
|
||||
}
|
||||
|
||||
// FLUX Pro request
|
||||
export interface BflFluxProRequest {
|
||||
guidance?: number;
|
||||
height?: number;
|
||||
image_prompt?: string | null;
|
||||
prompt?: string | null;
|
||||
safety_tolerance?: number;
|
||||
seed?: number | null;
|
||||
steps?: number;
|
||||
width?: number;
|
||||
}
|
||||
|
||||
// FLUX Dev request
|
||||
export interface BflFluxDevRequest {
|
||||
guidance?: number;
|
||||
height?: number;
|
||||
image_prompt?: string | null;
|
||||
prompt: string;
|
||||
safety_tolerance?: number;
|
||||
seed?: number | null;
|
||||
steps?: number;
|
||||
width?: number;
|
||||
}
|
||||
|
||||
// Model endpoint mapping
|
||||
export const BFL_ENDPOINTS = {
|
||||
'flux-dev': '/v1/flux-dev',
|
||||
'flux-kontext-max': '/v1/flux-kontext-max',
|
||||
'flux-kontext-pro': '/v1/flux-kontext-pro',
|
||||
'flux-pro': '/v1/flux-pro',
|
||||
'flux-pro-1.1': '/v1/flux-pro-1.1',
|
||||
'flux-pro-1.1-ultra': '/v1/flux-pro-1.1-ultra',
|
||||
} as const;
|
||||
|
||||
export type BflModelId = keyof typeof BFL_ENDPOINTS;
|
||||
|
||||
// Union type for all request types
|
||||
export type BflRequest =
|
||||
| BflFluxKontextRequest
|
||||
| BflFluxPro11Request
|
||||
| BflFluxPro11UltraRequest
|
||||
| BflFluxProRequest
|
||||
| BflFluxDevRequest;
|
||||
@@ -3,6 +3,7 @@ export { LobeAzureAI } from './azureai';
|
||||
export { LobeAzureOpenAI } from './azureOpenai';
|
||||
export * from './BaseAI';
|
||||
export { LobeBedrockAI } from './bedrock';
|
||||
export { LobeBflAI } from './bfl';
|
||||
export { LobeDeepSeekAI } from './deepseek';
|
||||
export * from './error';
|
||||
export { LobeGoogleAI } from './google';
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import createDebug from 'debug';
|
||||
|
||||
import { CreateImagePayload, CreateImageResponse } from '../types/image';
|
||||
import { type TaskResult, asyncifyPolling } from '../utils/asyncifyPolling';
|
||||
import { AgentRuntimeError } from '../utils/createError';
|
||||
import { CreateImageOptions } from '../utils/openaiCompatibleFactory';
|
||||
|
||||
@@ -139,93 +140,47 @@ export async function createQwenImage(
|
||||
// 1. Create image generation task
|
||||
const taskId = await createImageTask(payload, apiKey);
|
||||
|
||||
// 2. Poll task status until completion
|
||||
let taskStatus: QwenImageTaskResponse | null = null;
|
||||
let retries = 0;
|
||||
let consecutiveFailures = 0;
|
||||
const maxConsecutiveFailures = 3; // Allow up to 3 consecutive query failures
|
||||
// Using Infinity for maxRetries is safe because:
|
||||
// 1. Vercel runtime has execution time limits
|
||||
// 2. Qwen's API will eventually return FAILED status for timed-out tasks
|
||||
// 3. Our exponential backoff ensures reasonable retry intervals
|
||||
const maxRetries = Infinity;
|
||||
const initialRetryInterval = 500; // 500ms initial interval
|
||||
const maxRetryInterval = 5000; // 5 seconds max interval
|
||||
const backoffMultiplier = 1.5; // exponential backoff multiplier
|
||||
// 2. Poll task status until completion using asyncifyPolling
|
||||
const result = await asyncifyPolling<QwenImageTaskResponse, CreateImageResponse>({
|
||||
checkStatus: (taskStatus: QwenImageTaskResponse): TaskResult<CreateImageResponse> => {
|
||||
log('Task %s status: %s', taskId, taskStatus.output.task_status);
|
||||
|
||||
while (retries < maxRetries) {
|
||||
try {
|
||||
taskStatus = await queryTaskStatus(taskId, apiKey);
|
||||
consecutiveFailures = 0; // Reset consecutive failures on success
|
||||
} catch (error) {
|
||||
consecutiveFailures++;
|
||||
log(
|
||||
'Failed to query task status (attempt %d/%d, consecutive failures: %d/%d): %O',
|
||||
retries + 1,
|
||||
maxRetries,
|
||||
consecutiveFailures,
|
||||
maxConsecutiveFailures,
|
||||
error,
|
||||
);
|
||||
if (taskStatus.output.task_status === 'SUCCEEDED') {
|
||||
if (!taskStatus.output.results || taskStatus.output.results.length === 0) {
|
||||
return {
|
||||
error: new Error('Task succeeded but no images generated'),
|
||||
status: 'failed',
|
||||
};
|
||||
}
|
||||
|
||||
// If we've failed too many times in a row, give up
|
||||
if (consecutiveFailures >= maxConsecutiveFailures) {
|
||||
throw new Error(
|
||||
`Failed to query task status after ${consecutiveFailures} consecutive attempts: ${error}`,
|
||||
);
|
||||
const imageUrl = taskStatus.output.results[0].url;
|
||||
log('Image generated successfully: %s', imageUrl);
|
||||
|
||||
return {
|
||||
data: { imageUrl },
|
||||
status: 'success',
|
||||
};
|
||||
}
|
||||
|
||||
// Wait before retrying
|
||||
const currentRetryInterval = Math.min(
|
||||
initialRetryInterval * Math.pow(backoffMultiplier, retries),
|
||||
maxRetryInterval,
|
||||
);
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, currentRetryInterval);
|
||||
});
|
||||
retries++;
|
||||
continue; // Skip the rest of the loop and retry
|
||||
}
|
||||
|
||||
// At this point, taskStatus should not be null since we just got it successfully
|
||||
log(
|
||||
'Task %s status: %s (attempt %d/%d)',
|
||||
taskId,
|
||||
taskStatus!.output.task_status,
|
||||
retries + 1,
|
||||
maxRetries,
|
||||
);
|
||||
|
||||
if (taskStatus!.output.task_status === 'SUCCEEDED') {
|
||||
if (!taskStatus!.output.results || taskStatus!.output.results.length === 0) {
|
||||
throw new Error('Task succeeded but no images generated');
|
||||
if (taskStatus.output.task_status === 'FAILED') {
|
||||
const errorMessage = taskStatus.output.error_message || 'Image generation task failed';
|
||||
return {
|
||||
error: new Error(`Qwen image generation failed: ${errorMessage}`),
|
||||
status: 'failed',
|
||||
};
|
||||
}
|
||||
|
||||
// Return the first generated image
|
||||
const imageUrl = taskStatus!.output.results[0].url;
|
||||
log('Image generated successfully: %s', imageUrl);
|
||||
// Continue polling for pending/running status or other unknown statuses
|
||||
return { status: 'pending' };
|
||||
},
|
||||
logger: {
|
||||
debug: (message: any, ...args: any[]) => log(message, ...args),
|
||||
error: (message: any, ...args: any[]) => log(message, ...args),
|
||||
},
|
||||
pollingQuery: () => queryTaskStatus(taskId, apiKey),
|
||||
});
|
||||
|
||||
return { imageUrl };
|
||||
} else if (taskStatus!.output.task_status === 'FAILED') {
|
||||
throw new Error(taskStatus!.output.error_message || 'Image generation task failed');
|
||||
}
|
||||
|
||||
// Calculate dynamic retry interval with exponential backoff
|
||||
const currentRetryInterval = Math.min(
|
||||
initialRetryInterval * Math.pow(backoffMultiplier, retries),
|
||||
maxRetryInterval,
|
||||
);
|
||||
|
||||
log('Waiting %dms before next retry', currentRetryInterval);
|
||||
|
||||
// Wait before retrying
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, currentRetryInterval);
|
||||
});
|
||||
retries++;
|
||||
}
|
||||
|
||||
throw new Error(`Image generation timeout after ${maxRetries} attempts`);
|
||||
return result;
|
||||
} catch (error) {
|
||||
log('Error in createQwenImage: %O', error);
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import { LobeAzureOpenAI } from './azureOpenai';
|
||||
import { LobeAzureAI } from './azureai';
|
||||
import { LobeBaichuanAI } from './baichuan';
|
||||
import { LobeBedrockAI } from './bedrock';
|
||||
import { LobeBflAI } from './bfl';
|
||||
import { LobeCloudflareAI } from './cloudflare';
|
||||
import { LobeCohereAI } from './cohere';
|
||||
import { LobeDeepSeekAI } from './deepseek';
|
||||
@@ -65,6 +66,7 @@ export const providerRuntimeMap = {
|
||||
azureai: LobeAzureAI,
|
||||
baichuan: LobeBaichuanAI,
|
||||
bedrock: LobeBedrockAI,
|
||||
bfl: LobeBflAI,
|
||||
cloudflare: LobeCloudflareAI,
|
||||
cohere: LobeCohereAI,
|
||||
deepseek: LobeDeepSeekAI,
|
||||
|
||||
491
packages/model-runtime/src/utils/asyncifyPolling.test.ts
Normal file
491
packages/model-runtime/src/utils/asyncifyPolling.test.ts
Normal file
@@ -0,0 +1,491 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { type TaskResult, asyncifyPolling } from './asyncifyPolling';
|
||||
|
||||
describe('asyncifyPolling', () => {
|
||||
describe('basic functionality', () => {
|
||||
it('should return data when task succeeds immediately', async () => {
|
||||
const mockTask = vi.fn().mockResolvedValue({ status: 'completed', data: 'result' });
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({
|
||||
status: 'success',
|
||||
data: 'result',
|
||||
} as TaskResult<string>);
|
||||
|
||||
const result = await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
});
|
||||
|
||||
expect(result).toBe('result');
|
||||
expect(mockTask).toHaveBeenCalledTimes(1);
|
||||
expect(mockCheckStatus).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should poll multiple times until success', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'completed', data: 'final-result' });
|
||||
|
||||
const mockCheckStatus = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'success', data: 'final-result' });
|
||||
|
||||
const result = await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
initialInterval: 10, // fast test
|
||||
});
|
||||
|
||||
expect(result).toBe('final-result');
|
||||
expect(mockTask).toHaveBeenCalledTimes(3);
|
||||
expect(mockCheckStatus).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('should throw error when task fails', async () => {
|
||||
const mockTask = vi.fn().mockResolvedValue({ status: 'failed', error: 'Task failed' });
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({
|
||||
status: 'failed',
|
||||
error: new Error('Task failed'),
|
||||
});
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
}),
|
||||
).rejects.toThrow('Task failed');
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle pending status correctly', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ status: 'processing' })
|
||||
.mockResolvedValueOnce({ status: 'done' });
|
||||
|
||||
const mockCheckStatus = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'success', data: 'completed' });
|
||||
|
||||
const result = await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
initialInterval: 10,
|
||||
});
|
||||
|
||||
expect(result).toBe('completed');
|
||||
expect(mockTask).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('retry mechanism', () => {
|
||||
it('should retry with exponential backoff', async () => {
|
||||
const startTime = Date.now();
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'success' });
|
||||
|
||||
const mockCheckStatus = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'success', data: 'done' });
|
||||
|
||||
await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
initialInterval: 50,
|
||||
backoffMultiplier: 2,
|
||||
maxInterval: 200,
|
||||
});
|
||||
|
||||
const elapsed = Date.now() - startTime;
|
||||
// Should wait at least 50ms + 100ms = 150ms
|
||||
expect(elapsed).toBeGreaterThan(140);
|
||||
});
|
||||
|
||||
it('should respect maxInterval limit', async () => {
|
||||
const intervals: number[] = [];
|
||||
const originalSetTimeout = global.setTimeout;
|
||||
|
||||
global.setTimeout = vi.fn((callback, delay) => {
|
||||
intervals.push(delay as number);
|
||||
return originalSetTimeout(callback, 1); // fast execution
|
||||
}) as any;
|
||||
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'success' });
|
||||
|
||||
const mockCheckStatus = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'success', data: 'done' });
|
||||
|
||||
await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
initialInterval: 100,
|
||||
backoffMultiplier: 3,
|
||||
maxInterval: 200,
|
||||
});
|
||||
|
||||
// Intervals should be: 100, 200 (capped), 200 (capped)
|
||||
expect(intervals).toEqual([100, 200, 200]);
|
||||
|
||||
global.setTimeout = originalSetTimeout;
|
||||
});
|
||||
|
||||
it('should stop after maxRetries', async () => {
|
||||
const mockTask = vi.fn().mockResolvedValue({ status: 'pending' });
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({ status: 'pending' });
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxRetries: 3,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow(/timeout after 3 attempts/);
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should handle consecutive failures', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Network error 1'))
|
||||
.mockRejectedValueOnce(new Error('Network error 2'))
|
||||
.mockResolvedValueOnce({ status: 'success' });
|
||||
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({ status: 'success', data: 'recovered' });
|
||||
|
||||
const result = await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxConsecutiveFailures: 3,
|
||||
initialInterval: 1,
|
||||
});
|
||||
|
||||
expect(result).toBe('recovered');
|
||||
expect(mockTask).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('should throw after maxConsecutiveFailures', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Network error 1'))
|
||||
.mockRejectedValueOnce(new Error('Network error 2'))
|
||||
.mockRejectedValueOnce(new Error('Network error 3'));
|
||||
|
||||
const mockCheckStatus = vi.fn();
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxConsecutiveFailures: 2, // 允许最多2次连续失败
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow(/consecutive attempts/);
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(2); // 第1次失败,第2次失败,然后抛出错误
|
||||
expect(mockCheckStatus).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reset consecutive failures on success', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Network error 1')) // Failure 1 (consecutiveFailures=1)
|
||||
.mockResolvedValueOnce({ status: 'pending' }) // Success 1 (reset to 0)
|
||||
.mockRejectedValueOnce(new Error('Network error 2')) // Failure 2 (consecutiveFailures=1)
|
||||
.mockRejectedValueOnce(new Error('Network error 3')) // Failure 3 (consecutiveFailures=2)
|
||||
.mockResolvedValueOnce({ status: 'success' }); // Success 2 (return result)
|
||||
|
||||
const mockCheckStatus = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ status: 'pending' }) // For success 1
|
||||
.mockReturnValueOnce({ status: 'success', data: 'final' }); // For success 2
|
||||
|
||||
const result = await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxConsecutiveFailures: 3, // Allow up to 3 consecutive failures (since there are 2 consecutive failures)
|
||||
initialInterval: 1,
|
||||
});
|
||||
|
||||
expect(result).toBe('final');
|
||||
expect(mockTask).toHaveBeenCalledTimes(5); // Total 5 calls
|
||||
});
|
||||
});
|
||||
|
||||
describe('configuration', () => {
|
||||
it('should use custom intervals and multipliers', async () => {
|
||||
const intervals: number[] = [];
|
||||
const originalSetTimeout = global.setTimeout;
|
||||
|
||||
global.setTimeout = vi.fn((callback, delay) => {
|
||||
intervals.push(delay as number);
|
||||
return originalSetTimeout(callback, 1);
|
||||
}) as any;
|
||||
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ status: 'pending' })
|
||||
.mockResolvedValueOnce({ status: 'success' });
|
||||
|
||||
const mockCheckStatus = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ status: 'pending' })
|
||||
.mockReturnValueOnce({ status: 'success', data: 'done' });
|
||||
|
||||
await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
initialInterval: 200,
|
||||
backoffMultiplier: 1.2,
|
||||
});
|
||||
|
||||
expect(intervals[0]).toBe(200);
|
||||
|
||||
global.setTimeout = originalSetTimeout;
|
||||
});
|
||||
|
||||
it('should accept custom logger function', async () => {
|
||||
const mockLogger = {
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
};
|
||||
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Test error'))
|
||||
.mockResolvedValueOnce({ status: 'success' });
|
||||
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({ status: 'success', data: 'done' });
|
||||
|
||||
await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
logger: mockLogger,
|
||||
maxConsecutiveFailures: 3,
|
||||
initialInterval: 1,
|
||||
});
|
||||
|
||||
expect(mockLogger.debug).toHaveBeenCalled();
|
||||
expect(mockLogger.error).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle immediate failure', async () => {
|
||||
const mockTask = vi.fn().mockResolvedValue({ error: 'immediate failure' });
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({
|
||||
status: 'failed',
|
||||
error: new Error('immediate failure'),
|
||||
});
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
}),
|
||||
).rejects.toThrow('immediate failure');
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle task throwing exceptions', async () => {
|
||||
const mockTask = vi.fn().mockRejectedValue(new Error('Task exception'));
|
||||
const mockCheckStatus = vi.fn();
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxConsecutiveFailures: 1,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow(/consecutive attempts/);
|
||||
});
|
||||
|
||||
it('should timeout correctly with maxRetries = 1', async () => {
|
||||
const mockTask = vi.fn().mockResolvedValue({ status: 'pending' });
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({ status: 'pending' });
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxRetries: 1,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow(/timeout after 1 attempts/);
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('custom error handling', () => {
|
||||
it('should allow continuing polling via onPollingError', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockRejectedValueOnce(new Error('Another error'))
|
||||
.mockResolvedValueOnce({ status: 'success' });
|
||||
|
||||
const mockCheckStatus = vi.fn().mockReturnValue({ status: 'success', data: 'final-result' });
|
||||
|
||||
const onPollingError = vi.fn().mockReturnValue({
|
||||
isContinuePolling: true,
|
||||
});
|
||||
|
||||
const result = await asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
onPollingError,
|
||||
initialInterval: 1,
|
||||
});
|
||||
|
||||
expect(result).toBe('final-result');
|
||||
expect(mockTask).toHaveBeenCalledTimes(3);
|
||||
expect(onPollingError).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Check that error context was passed correctly
|
||||
expect(onPollingError).toHaveBeenCalledWith({
|
||||
error: expect.any(Error),
|
||||
retries: expect.any(Number),
|
||||
consecutiveFailures: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should stop polling when onPollingError returns false', async () => {
|
||||
const mockTask = vi.fn().mockRejectedValue(new Error('Fatal error'));
|
||||
const mockCheckStatus = vi.fn();
|
||||
|
||||
const onPollingError = vi.fn().mockReturnValue({
|
||||
isContinuePolling: false,
|
||||
});
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
onPollingError,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow('Fatal error');
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(1);
|
||||
expect(onPollingError).toHaveBeenCalledTimes(1);
|
||||
expect(mockCheckStatus).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw custom error when provided by onPollingError', async () => {
|
||||
const mockTask = vi.fn().mockRejectedValue(new Error('Original error'));
|
||||
const mockCheckStatus = vi.fn();
|
||||
|
||||
const customError = new Error('Custom error message');
|
||||
const onPollingError = vi.fn().mockReturnValue({
|
||||
isContinuePolling: false,
|
||||
error: customError,
|
||||
});
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
onPollingError,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow('Custom error message');
|
||||
|
||||
expect(onPollingError).toHaveBeenCalledWith({
|
||||
error: expect.objectContaining({ message: 'Original error' }),
|
||||
retries: 0,
|
||||
consecutiveFailures: 1,
|
||||
});
|
||||
});
|
||||
|
||||
it('should provide correct context information to onPollingError', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Error 1'))
|
||||
.mockRejectedValueOnce(new Error('Error 2'))
|
||||
.mockRejectedValueOnce(new Error('Error 3'));
|
||||
|
||||
const mockCheckStatus = vi.fn();
|
||||
|
||||
const onPollingError = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ isContinuePolling: true })
|
||||
.mockReturnValueOnce({ isContinuePolling: true })
|
||||
.mockReturnValueOnce({ isContinuePolling: false });
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
onPollingError,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow('Error 3');
|
||||
|
||||
// Verify context progression
|
||||
expect(onPollingError).toHaveBeenNthCalledWith(1, {
|
||||
error: expect.objectContaining({ message: 'Error 1' }),
|
||||
retries: 0,
|
||||
consecutiveFailures: 1,
|
||||
});
|
||||
|
||||
expect(onPollingError).toHaveBeenNthCalledWith(2, {
|
||||
error: expect.objectContaining({ message: 'Error 2' }),
|
||||
retries: 1,
|
||||
consecutiveFailures: 2,
|
||||
});
|
||||
|
||||
expect(onPollingError).toHaveBeenNthCalledWith(3, {
|
||||
error: expect.objectContaining({ message: 'Error 3' }),
|
||||
retries: 2,
|
||||
consecutiveFailures: 3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to default behavior when onPollingError is not provided', async () => {
|
||||
const mockTask = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error('Error 1'))
|
||||
.mockRejectedValueOnce(new Error('Error 2'))
|
||||
.mockRejectedValueOnce(new Error('Error 3'));
|
||||
|
||||
const mockCheckStatus = vi.fn();
|
||||
|
||||
await expect(
|
||||
asyncifyPolling({
|
||||
pollingQuery: mockTask,
|
||||
checkStatus: mockCheckStatus,
|
||||
maxConsecutiveFailures: 2,
|
||||
initialInterval: 1,
|
||||
}),
|
||||
).rejects.toThrow(/consecutive attempts/);
|
||||
|
||||
expect(mockTask).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
175
packages/model-runtime/src/utils/asyncifyPolling.ts
Normal file
175
packages/model-runtime/src/utils/asyncifyPolling.ts
Normal file
@@ -0,0 +1,175 @@
|
||||
export interface TaskResult<T> {
|
||||
data?: T;
|
||||
error?: any;
|
||||
status: 'pending' | 'success' | 'failed';
|
||||
}
|
||||
|
||||
export interface PollingErrorContext {
|
||||
consecutiveFailures: number;
|
||||
error: any;
|
||||
retries: number;
|
||||
}
|
||||
|
||||
export interface PollingErrorResult {
|
||||
error?: any;
|
||||
isContinuePolling: boolean; // If provided, will replace the original error when thrown
|
||||
}
|
||||
|
||||
export interface AsyncifyPollingOptions<T, R> {
|
||||
// Default 5000ms
|
||||
backoffMultiplier?: number;
|
||||
|
||||
// Status check function to determine task result
|
||||
checkStatus: (result: T) => TaskResult<R>;
|
||||
|
||||
// Retry configuration
|
||||
initialInterval?: number;
|
||||
// Optional logger
|
||||
logger?: {
|
||||
debug?: (...args: any[]) => void;
|
||||
error?: (...args: any[]) => void;
|
||||
};
|
||||
// Default 1.5
|
||||
maxConsecutiveFailures?: number;
|
||||
// Default 500ms
|
||||
maxInterval?: number; // Default 3
|
||||
maxRetries?: number; // Default Infinity
|
||||
|
||||
// Optional custom error handler for polling query failures
|
||||
onPollingError?: (context: PollingErrorContext) => PollingErrorResult;
|
||||
|
||||
// The polling function to execute repeatedly
|
||||
pollingQuery: () => Promise<T>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert polling pattern to async/await pattern
|
||||
*
|
||||
* @param options Polling configuration options
|
||||
* @returns Promise<R> The data returned when task completes
|
||||
* @throws Error When task fails or times out
|
||||
*/
|
||||
export async function asyncifyPolling<T, R>(options: AsyncifyPollingOptions<T, R>): Promise<R> {
|
||||
const {
|
||||
pollingQuery,
|
||||
checkStatus,
|
||||
initialInterval = 500,
|
||||
maxInterval = 5000,
|
||||
backoffMultiplier = 1.5,
|
||||
maxConsecutiveFailures = 3,
|
||||
maxRetries = Infinity,
|
||||
onPollingError,
|
||||
logger,
|
||||
} = options;
|
||||
|
||||
let retries = 0;
|
||||
let consecutiveFailures = 0;
|
||||
|
||||
while (retries < maxRetries) {
|
||||
let pollingResult: T;
|
||||
|
||||
try {
|
||||
// Execute polling function
|
||||
pollingResult = await pollingQuery();
|
||||
|
||||
// Reset consecutive failures counter on successful execution
|
||||
consecutiveFailures = 0;
|
||||
} catch (error) {
|
||||
// Polling function execution failed (network error, etc.)
|
||||
consecutiveFailures++;
|
||||
|
||||
logger?.error?.(
|
||||
`Failed to execute polling function (attempt ${retries + 1}/${maxRetries === Infinity ? '∞' : maxRetries}, consecutive failures: ${consecutiveFailures}/${maxConsecutiveFailures}):`,
|
||||
error,
|
||||
);
|
||||
|
||||
// Handle custom error processing if provided
|
||||
if (onPollingError) {
|
||||
const errorResult = onPollingError({
|
||||
consecutiveFailures,
|
||||
error,
|
||||
retries,
|
||||
});
|
||||
|
||||
if (!errorResult.isContinuePolling) {
|
||||
// Custom error handler decided to stop polling
|
||||
throw errorResult.error || error;
|
||||
}
|
||||
|
||||
// Custom error handler decided to continue polling
|
||||
logger?.debug?.('Custom error handler decided to continue polling');
|
||||
} else {
|
||||
// Default behavior: check if maximum consecutive failures reached
|
||||
if (consecutiveFailures >= maxConsecutiveFailures) {
|
||||
throw new Error(
|
||||
`Failed to execute polling function after ${consecutiveFailures} consecutive attempts: ${error}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait before retry and continue to next loop iteration
|
||||
if (retries < maxRetries - 1) {
|
||||
const currentInterval = Math.min(
|
||||
initialInterval * Math.pow(backoffMultiplier, retries),
|
||||
maxInterval,
|
||||
);
|
||||
|
||||
logger?.debug?.(`Waiting ${currentInterval}ms before next retry`);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, currentInterval);
|
||||
});
|
||||
}
|
||||
|
||||
retries++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check task status
|
||||
const statusResult = checkStatus(pollingResult);
|
||||
|
||||
logger?.debug?.(`Task status: ${statusResult.status} (attempt ${retries + 1})`);
|
||||
|
||||
switch (statusResult.status) {
|
||||
case 'success': {
|
||||
return statusResult.data as R;
|
||||
}
|
||||
|
||||
case 'failed': {
|
||||
// Task logic failed, throw error immediately (not counted as consecutive failure)
|
||||
throw statusResult.error || new Error('Task failed');
|
||||
}
|
||||
|
||||
case 'pending': {
|
||||
// Continue polling
|
||||
break;
|
||||
}
|
||||
|
||||
default: {
|
||||
// Unknown status, treat as pending
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait before next retry if not the last attempt
|
||||
if (retries < maxRetries - 1) {
|
||||
// Calculate dynamic retry interval with exponential backoff
|
||||
const currentInterval = Math.min(
|
||||
initialInterval * Math.pow(backoffMultiplier, retries),
|
||||
maxInterval,
|
||||
);
|
||||
|
||||
logger?.debug?.(`Waiting ${currentInterval}ms before next retry`);
|
||||
|
||||
// Wait for retry interval
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, currentInterval);
|
||||
});
|
||||
}
|
||||
|
||||
retries++;
|
||||
}
|
||||
|
||||
// Maximum retries reached
|
||||
throw new Error(`Task timeout after ${maxRetries} attempts`);
|
||||
}
|
||||
@@ -46,7 +46,7 @@ const ConfigPanel = memo(() => {
|
||||
const { showDimensionControl } = useDimensionControl();
|
||||
|
||||
return (
|
||||
<Flexbox gap={32} padding={12}>
|
||||
<Flexbox gap={32} padding={12} style={{ overflow: 'auto' }}>
|
||||
<ConfigItemLayout>
|
||||
<ModelSelect />
|
||||
</ConfigItemLayout>
|
||||
|
||||
145
src/config/aiModels/bfl.ts
Normal file
145
src/config/aiModels/bfl.ts
Normal file
@@ -0,0 +1,145 @@
|
||||
import { PRESET_ASPECT_RATIOS } from '@/const/image';
|
||||
import { ModelParamsSchema } from '@/libs/standard-parameters';
|
||||
import { AIImageModelCard } from '@/types/aiModel';
|
||||
|
||||
// https://docs.bfl.ai/api-reference/tasks/edit-or-create-an-image-with-flux-kontext-pro
|
||||
// official support 21:9 ~ 9:21 (ratio 0.43 ~ 2.33)
|
||||
const calculateRatio = (aspectRatio: string): number => {
|
||||
const [width, height] = aspectRatio.split(':').map(Number);
|
||||
return width / height;
|
||||
};
|
||||
|
||||
const defaultAspectRatios = PRESET_ASPECT_RATIOS.filter((ratio) => {
|
||||
const value = calculateRatio(ratio);
|
||||
// BFL API supports ratio range: 21:9 ~ 9:21 (approximately 0.43 ~ 2.33)
|
||||
// Use a small tolerance for floating point comparison
|
||||
return value >= 9 / 21 - 0.001 && value <= 21 / 9 + 0.001;
|
||||
});
|
||||
|
||||
const fluxKontextSeriesParamsSchema: ModelParamsSchema = {
|
||||
aspectRatio: {
|
||||
default: '1:1',
|
||||
enum: defaultAspectRatios,
|
||||
},
|
||||
imageUrls: {
|
||||
default: [],
|
||||
},
|
||||
prompt: { default: '' },
|
||||
seed: { default: null },
|
||||
};
|
||||
|
||||
const imageModels: AIImageModelCard[] = [
|
||||
// https://docs.bfl.ai/api-reference/tasks/edit-or-create-an-image-with-flux-kontext-pro
|
||||
{
|
||||
description: '最先进的上下文图像生成和编辑——结合文本和图像以获得精确、连贯的结果。',
|
||||
displayName: 'FLUX.1 Kontext [pro]',
|
||||
enabled: true,
|
||||
id: 'flux-kontext-pro',
|
||||
parameters: fluxKontextSeriesParamsSchema,
|
||||
// check: https://bfl.ai/pricing
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.04, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
releasedAt: '2025-05-29',
|
||||
type: 'image',
|
||||
},
|
||||
// https://docs.bfl.ai/api-reference/tasks/edit-or-create-an-image-with-flux-kontext-max
|
||||
{
|
||||
description: '最先进的上下文图像生成和编辑——结合文本和图像以获得精确、连贯的结果。',
|
||||
displayName: 'FLUX.1 Kontext [max]',
|
||||
enabled: true,
|
||||
id: 'flux-kontext-max',
|
||||
parameters: fluxKontextSeriesParamsSchema,
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.08, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
releasedAt: '2025-05-29',
|
||||
type: 'image',
|
||||
},
|
||||
// https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux-11-[pro]
|
||||
{
|
||||
description: '升级版专业级AI图像生成模型——提供卓越的图像质量和精确的提示词遵循能力。',
|
||||
displayName: 'FLUX1.1 [pro] ',
|
||||
enabled: true,
|
||||
id: 'flux-pro-1.1',
|
||||
parameters: {
|
||||
height: { default: 768, max: 1440, min: 256, step: 32 },
|
||||
imageUrl: { default: null },
|
||||
prompt: { default: '' },
|
||||
seed: { default: null },
|
||||
width: { default: 1024, max: 1440, min: 256, step: 32 },
|
||||
},
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.06, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
releasedAt: '2024-10-02',
|
||||
type: 'image',
|
||||
},
|
||||
// https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux-11-[pro]-with-ultra-mode-and-optional-raw-mode
|
||||
{
|
||||
description: '超高分辨率AI图像生成——支持4兆像素输出,10秒内生成超清图像。',
|
||||
displayName: 'FLUX1.1 [pro] Ultra',
|
||||
enabled: true,
|
||||
id: 'flux-pro-1.1-ultra',
|
||||
parameters: {
|
||||
aspectRatio: {
|
||||
default: '16:9',
|
||||
enum: defaultAspectRatios,
|
||||
},
|
||||
imageUrl: { default: null },
|
||||
prompt: { default: '' },
|
||||
seed: { default: null },
|
||||
},
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.06, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
releasedAt: '2024-11-06',
|
||||
type: 'image',
|
||||
},
|
||||
// https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux1-[pro]
|
||||
{
|
||||
description: '顶级商用AI图像生成模型——无与伦比的图像质量和多样化输出表现。',
|
||||
displayName: 'FLUX.1 [pro]',
|
||||
enabled: true,
|
||||
id: 'flux-pro',
|
||||
parameters: {
|
||||
cfg: { default: 2.5, max: 5, min: 1.5, step: 0.1 },
|
||||
height: { default: 768, max: 1440, min: 256, step: 32 },
|
||||
imageUrl: { default: null },
|
||||
prompt: { default: '' },
|
||||
seed: { default: null },
|
||||
steps: { default: 40, max: 50, min: 1 },
|
||||
width: { default: 1024, max: 1440, min: 256, step: 32 },
|
||||
},
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.025, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
releasedAt: '2024-08-01',
|
||||
type: 'image',
|
||||
},
|
||||
// https://docs.bfl.ai/api-reference/tasks/generate-an-image-with-flux1-[dev]
|
||||
{
|
||||
description: '开源研发版AI图像生成模型——高效优化,适合非商业用途的创新研究。',
|
||||
displayName: 'FLUX.1 [dev]',
|
||||
enabled: true,
|
||||
id: 'flux-dev',
|
||||
parameters: {
|
||||
cfg: { default: 3, max: 5, min: 1.5, step: 0.1 },
|
||||
height: { default: 768, max: 1440, min: 256, step: 32 },
|
||||
imageUrl: { default: null },
|
||||
prompt: { default: '' },
|
||||
seed: { default: null },
|
||||
steps: { default: 28, max: 50, min: 1 },
|
||||
width: { default: 1024, max: 1440, min: 256, step: 32 },
|
||||
},
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.025, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
releasedAt: '2024-08-01',
|
||||
type: 'image',
|
||||
},
|
||||
];
|
||||
|
||||
export const allModels = [...imageModels];
|
||||
|
||||
export default allModels;
|
||||
@@ -9,6 +9,7 @@ import { default as azure } from './azure';
|
||||
import { default as azureai } from './azureai';
|
||||
import { default as baichuan } from './baichuan';
|
||||
import { default as bedrock } from './bedrock';
|
||||
import { default as bfl } from './bfl';
|
||||
import { default as cloudflare } from './cloudflare';
|
||||
import { default as cohere } from './cohere';
|
||||
import { default as deepseek } from './deepseek';
|
||||
@@ -87,6 +88,7 @@ export const LOBE_DEFAULT_MODEL_LIST = buildDefaultModelList({
|
||||
azureai,
|
||||
baichuan,
|
||||
bedrock,
|
||||
bfl,
|
||||
cloudflare,
|
||||
cohere,
|
||||
deepseek,
|
||||
@@ -146,6 +148,7 @@ export { default as azure } from './azure';
|
||||
export { default as azureai } from './azureai';
|
||||
export { default as baichuan } from './baichuan';
|
||||
export { default as bedrock } from './bedrock';
|
||||
export { default as bfl } from './bfl';
|
||||
export { default as cloudflare } from './cloudflare';
|
||||
export { default as cohere } from './cohere';
|
||||
export { default as deepseek } from './deepseek';
|
||||
|
||||
@@ -166,13 +166,15 @@ export const getLLMConfig = () => {
|
||||
ENABLED_FAL: z.boolean(),
|
||||
FAL_API_KEY: z.string().optional(),
|
||||
|
||||
ENABLED_BFL: z.boolean(),
|
||||
BFL_API_KEY: z.string().optional(),
|
||||
|
||||
ENABLED_MODELSCOPE: z.boolean(),
|
||||
MODELSCOPE_API_KEY: z.string().optional(),
|
||||
|
||||
ENABLED_V0: z.boolean(),
|
||||
V0_API_KEY: z.string().optional(),
|
||||
|
||||
|
||||
ENABLED_AI302: z.boolean(),
|
||||
AI302_API_KEY: z.string().optional(),
|
||||
|
||||
@@ -342,6 +344,9 @@ export const getLLMConfig = () => {
|
||||
ENABLED_FAL: process.env.ENABLED_FAL !== '0',
|
||||
FAL_API_KEY: process.env.FAL_API_KEY,
|
||||
|
||||
ENABLED_BFL: !!process.env.BFL_API_KEY,
|
||||
BFL_API_KEY: process.env.BFL_API_KEY,
|
||||
|
||||
ENABLED_MODELSCOPE: !!process.env.MODELSCOPE_API_KEY,
|
||||
MODELSCOPE_API_KEY: process.env.MODELSCOPE_API_KEY,
|
||||
|
||||
|
||||
21
src/config/modelProviders/bfl.ts
Normal file
21
src/config/modelProviders/bfl.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { ModelProviderCard } from '@/types/llm';
|
||||
|
||||
/**
|
||||
* @see https://docs.bfl.ai/
|
||||
*/
|
||||
const Bfl: ModelProviderCard = {
|
||||
chatModels: [],
|
||||
description: '领先的前沿人工智能研究实验室,构建明日的视觉基础设施。',
|
||||
enabled: true,
|
||||
id: 'bfl',
|
||||
name: 'Black Forest Labs',
|
||||
settings: {
|
||||
disableBrowserRequest: true,
|
||||
showAddNewModel: false,
|
||||
showChecker: false,
|
||||
showModelFetcher: false,
|
||||
},
|
||||
url: 'https://bfl.ai/',
|
||||
};
|
||||
|
||||
export default Bfl;
|
||||
@@ -9,6 +9,7 @@ import AzureProvider from './azure';
|
||||
import AzureAIProvider from './azureai';
|
||||
import BaichuanProvider from './baichuan';
|
||||
import BedrockProvider from './bedrock';
|
||||
import BflProvider from './bfl';
|
||||
import CloudflareProvider from './cloudflare';
|
||||
import CohereProvider from './cohere';
|
||||
import DeepSeekProvider from './deepseek';
|
||||
@@ -132,6 +133,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
|
||||
HuggingFaceProvider,
|
||||
CloudflareProvider,
|
||||
GithubProvider,
|
||||
BflProvider,
|
||||
NovitaProvider,
|
||||
PPIOProvider,
|
||||
NvidiaProvider,
|
||||
@@ -191,6 +193,7 @@ export { default as AzureProviderCard } from './azure';
|
||||
export { default as AzureAIProviderCard } from './azureai';
|
||||
export { default as BaichuanProviderCard } from './baichuan';
|
||||
export { default as BedrockProviderCard } from './bedrock';
|
||||
export { default as BflProviderCard } from './bfl';
|
||||
export { default as CloudflareProviderCard } from './cloudflare';
|
||||
export { default as CohereProviderCard } from './cohere';
|
||||
export { default as DeepSeekProviderCard } from './deepseek';
|
||||
|
||||
@@ -77,17 +77,13 @@ export function useDimensionControl() {
|
||||
const aspectRatioOptions = useMemo(() => {
|
||||
const modelOptions = paramsSchema?.aspectRatio?.enum || [];
|
||||
|
||||
// 合并选项,优先使用预设选项,然后添加模型特有的选项
|
||||
const allOptions = [...PRESET_ASPECT_RATIOS];
|
||||
// 如果 schema 里面有 aspectRatio 并且不为空,直接使用 schema 里面的选项
|
||||
if (modelOptions.length > 0) {
|
||||
return modelOptions;
|
||||
}
|
||||
|
||||
// 添加模型选项中不在预设中的选项
|
||||
modelOptions.forEach((option) => {
|
||||
if (!allOptions.includes(option)) {
|
||||
allOptions.push(option);
|
||||
}
|
||||
});
|
||||
|
||||
return allOptions;
|
||||
// 否则使用预设选项
|
||||
return PRESET_ASPECT_RATIOS;
|
||||
}, [paramsSchema]);
|
||||
|
||||
// 只要不是所有维度相关的控件都不显示,那么这个容器就应该显示
|
||||
|
||||
Reference in New Issue
Block a user