From b4ba8bf454dfcdc421ea4a2ceccdd253d9e098f0 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Wed, 7 Jan 2026 23:22:19 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20model=20runtime=20provide?= =?UTF-8?q?r=20issue=20=20(#11314)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * upload * update * fix * fix tests --- .github/workflows/test.yml | 6 +- .gitignore | 1 + src/server/modules/ModelRuntime/index.test.ts | 215 +++++++++++++++++- src/server/modules/ModelRuntime/index.ts | 52 ++++- vitest.config.mts | 1 + 5 files changed, 263 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5a5c4ad55f..f4a97ffef5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -102,8 +102,8 @@ jobs: - name: Install deps run: bun i - - name: Run tests with blob reporter - run: bunx vitest --coverage --reporter=blob --silent='passed-only' --shard=${{ matrix.shard }}/2 + - name: Run tests + run: bunx vitest --coverage --silent='passed-only' --shard=${{ matrix.shard }}/2 - name: Upload blob report if: ${{ !cancelled() }} @@ -139,7 +139,7 @@ jobs: merge-multiple: true - name: Merge reports - run: bunx vitest --merge-reports --coverage + run: bunx vitest --merge-reports --reporter=default --coverage - name: Upload App Coverage to Codecov uses: codecov/codecov-action@v5 diff --git a/.gitignore b/.gitignore index 7886911589..6497d86cc0 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,4 @@ CLAUDE.local.md e2e/reports out i18n-unused-keys-report.json +.vitest-reports diff --git a/src/server/modules/ModelRuntime/index.test.ts b/src/server/modules/ModelRuntime/index.test.ts index a52df0182c..ecedfc24f1 100644 --- a/src/server/modules/ModelRuntime/index.test.ts +++ b/src/server/modules/ModelRuntime/index.test.ts @@ -26,7 +26,7 @@ import { ClientSecretPayload } from '@lobechat/types'; import { ModelProvider } from 'model-bank'; import { describe, expect, it, vi } from 'vitest'; -import { initModelRuntimeWithUserPayload } from './index'; +import { buildPayloadFromKeyVaults, initModelRuntimeWithUserPayload } from './index'; // 模拟依赖项 vi.mock('@/envs/llm', () => ({ @@ -496,3 +496,216 @@ describe('initModelRuntimeWithUserPayload method', () => { }); }); }); + +/** + * Test cases for buildPayloadFromKeyVaults function + * This function builds ClientSecretPayload based on runtimeProvider (sdkType) + * to ensure provider-specific fields are correctly forwarded + */ +describe('buildPayloadFromKeyVaults', () => { + describe('should build payload with correct fields based on runtimeProvider', () => { + it('OpenAI compatible: returns apiKey, baseURL and runtimeProvider', () => { + const keyVaults = { + apiKey: 'test-api-key', + baseURL: 'https://custom-endpoint.com/v1', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.OpenAI); + + expect(payload).toEqual({ + apiKey: 'test-api-key', + baseURL: 'https://custom-endpoint.com/v1', + runtimeProvider: ModelProvider.OpenAI, + }); + }); + + it('Azure: returns apiKey, baseURL, azureApiVersion and runtimeProvider', () => { + const keyVaults = { + apiKey: 'azure-api-key', + baseURL: 'https://my-azure.openai.azure.com', + apiVersion: '2024-06-01', + endpoint: 'https://fallback-endpoint.com', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure); + + expect(payload).toEqual({ + apiKey: 'azure-api-key', + azureApiVersion: '2024-06-01', + baseURL: 'https://my-azure.openai.azure.com', + runtimeProvider: ModelProvider.Azure, + }); + }); + + it('Azure: uses endpoint as fallback when baseURL is not provided', () => { + const keyVaults = { + apiKey: 'azure-api-key', + endpoint: 'https://fallback-endpoint.com', + apiVersion: '2024-06-01', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure); + + expect(payload.baseURL).toBe('https://fallback-endpoint.com'); + }); + + it('Cloudflare: returns apiKey, cloudflareBaseURLOrAccountID and runtimeProvider', () => { + const keyVaults = { + apiKey: 'cloudflare-api-key', + baseURLOrAccountID: 'my-account-id', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Cloudflare); + + expect(payload).toEqual({ + apiKey: 'cloudflare-api-key', + cloudflareBaseURLOrAccountID: 'my-account-id', + runtimeProvider: ModelProvider.Cloudflare, + }); + }); + + it('Bedrock: returns AWS credentials and runtimeProvider', () => { + const keyVaults = { + accessKeyId: 'aws-access-key', + secretAccessKey: 'aws-secret-key', + region: 'us-east-1', + sessionToken: 'session-token', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Bedrock); + + expect(payload).toEqual({ + apiKey: 'aws-secret-keyaws-access-key', + awsAccessKeyId: 'aws-access-key', + awsRegion: 'us-east-1', + awsSecretAccessKey: 'aws-secret-key', + awsSessionToken: 'session-token', + runtimeProvider: ModelProvider.Bedrock, + }); + }); + + it('Ollama: returns baseURL and runtimeProvider', () => { + const keyVaults = { + baseURL: 'http://localhost:11434', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Ollama); + + expect(payload).toEqual({ + baseURL: 'http://localhost:11434', + runtimeProvider: ModelProvider.Ollama, + }); + }); + + it('VertexAI: returns apiKey, baseURL, vertexAIRegion and runtimeProvider', () => { + const keyVaults = { + apiKey: 'vertex-credentials-json', + baseURL: 'https://vertex-endpoint.com', + region: 'us-central1', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.VertexAI); + + expect(payload).toEqual({ + apiKey: 'vertex-credentials-json', + baseURL: 'https://vertex-endpoint.com', + runtimeProvider: ModelProvider.VertexAI, + vertexAIRegion: 'us-central1', + }); + }); + + it('ComfyUI: returns all auth fields and runtimeProvider', () => { + const keyVaults = { + apiKey: 'comfyui-api-key', + authType: 'bearer', + baseURL: 'http://localhost:8188', + customHeaders: { 'X-Custom': 'header' }, + password: 'pass', + username: 'user', + } as const; + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.ComfyUI); + + expect(payload).toEqual({ + apiKey: 'comfyui-api-key', + authType: 'bearer', + baseURL: 'http://localhost:8188', + customHeaders: { 'X-Custom': 'header' }, + password: 'pass', + runtimeProvider: ModelProvider.ComfyUI, + username: 'user', + }); + }); + + it('Unknown provider: falls back to default with apiKey, baseURL and runtimeProvider', () => { + const keyVaults = { + apiKey: 'unknown-api-key', + baseURL: 'https://unknown-endpoint.com', + }; + const payload = buildPayloadFromKeyVaults(keyVaults, 'unknown-provider'); + + expect(payload).toEqual({ + apiKey: 'unknown-api-key', + baseURL: 'https://unknown-endpoint.com', + runtimeProvider: 'unknown-provider', + }); + }); + }); + + describe('custom provider with sdkType should include provider-specific fields', () => { + it('custom provider with Azure sdkType includes azureApiVersion', () => { + const keyVaults = { + apiKey: 'custom-azure-key', + baseURL: 'https://custom-azure.openai.azure.com', + apiVersion: '2024-06-01', + }; + // Simulates a custom provider where runtimeProvider is resolved to 'azure' + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure); + + expect(payload.azureApiVersion).toBe('2024-06-01'); + expect(payload.runtimeProvider).toBe(ModelProvider.Azure); + }); + + it('custom provider with Cloudflare sdkType includes cloudflareBaseURLOrAccountID', () => { + const keyVaults = { + apiKey: 'custom-cloudflare-key', + baseURLOrAccountID: 'custom-account-id', + }; + // Simulates a custom provider where runtimeProvider is resolved to 'cloudflare' + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Cloudflare); + + expect(payload.cloudflareBaseURLOrAccountID).toBe('custom-account-id'); + expect(payload.runtimeProvider).toBe(ModelProvider.Cloudflare); + }); + + it('custom provider with Bedrock sdkType includes AWS credentials', () => { + const keyVaults = { + accessKeyId: 'custom-aws-id', + secretAccessKey: 'custom-aws-secret', + region: 'eu-west-1', + }; + // Simulates a custom provider where runtimeProvider is resolved to 'bedrock' + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Bedrock); + + expect(payload.awsAccessKeyId).toBe('custom-aws-id'); + expect(payload.awsSecretAccessKey).toBe('custom-aws-secret'); + expect(payload.awsRegion).toBe('eu-west-1'); + expect(payload.runtimeProvider).toBe(ModelProvider.Bedrock); + }); + + it('custom provider with Ollama sdkType includes baseURL', () => { + const keyVaults = { + baseURL: 'http://custom-ollama:11434', + }; + // Simulates a custom provider where runtimeProvider is resolved to 'ollama' + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Ollama); + + expect(payload.baseURL).toBe('http://custom-ollama:11434'); + expect(payload.runtimeProvider).toBe(ModelProvider.Ollama); + }); + + it('custom provider with VertexAI sdkType includes vertexAIRegion', () => { + const keyVaults = { + apiKey: 'custom-vertex-creds', + region: 'asia-northeast1', + }; + // Simulates a custom provider where runtimeProvider is resolved to 'vertexai' + const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.VertexAI); + + expect(payload.vertexAIRegion).toBe('asia-northeast1'); + expect(payload.runtimeProvider).toBe(ModelProvider.VertexAI); + }); + }); +}); diff --git a/src/server/modules/ModelRuntime/index.ts b/src/server/modules/ModelRuntime/index.ts index 5da1d5df4f..fe07cbe06d 100644 --- a/src/server/modules/ModelRuntime/index.ts +++ b/src/server/modules/ModelRuntime/index.ts @@ -32,6 +32,24 @@ type ProviderKeyVaults = OpenAICompatibleKeyVault & ComfyUIKeyVault & VertexAIKeyVault; +/** + * Resolve the runtime provider for a given provider. + * + * This is the server-side equivalent of the frontend's resolveRuntimeProvider function. + * For builtin providers, returns the provider as-is. + * For custom providers, returns the sdkType from settings (defaults to 'openai'). + * + * @param provider - The provider id + * @param sdkType - The sdkType from provider settings + * @returns The resolved runtime provider + */ +const resolveRuntimeProvider = (provider: string, sdkType?: string): string => { + const isBuiltin = Object.values(ModelProvider).includes(provider as ModelProvider); + if (isBuiltin) return provider; + + return sdkType || 'openai'; +}; + /** * Build ClientSecretPayload from keyVaults stored in database * @@ -39,15 +57,21 @@ type ProviderKeyVaults = OpenAICompatibleKeyVault & * It converts the keyVaults object from database to the ClientSecretPayload format * expected by initModelRuntimeWithUserPayload. * - * @param provider - The model provider + * For custom providers, we use runtimeProvider (sdkType) to determine which fields + * to include in the payload. This ensures that provider-specific fields like + * cloudflareBaseURLOrAccountID or azureApiVersion are correctly forwarded. + * * @param keyVaults - The keyVaults object from database (already decrypted) + * @param runtimeProvider - The runtime provider (sdkType) to use for building payload * @returns ClientSecretPayload for the provider */ export const buildPayloadFromKeyVaults = ( - provider: string, keyVaults: ProviderKeyVaults, + runtimeProvider: string, ): ClientSecretPayload => { - switch (provider) { + // Use runtimeProvider to determine which fields to include + // This handles both builtin providers and custom providers with sdkType + switch (runtimeProvider) { case ModelProvider.Bedrock: { const { accessKeyId, region, secretAccessKey, sessionToken } = keyVaults; const apiKey = (secretAccessKey || '') + (accessKeyId || ''); @@ -58,6 +82,7 @@ export const buildPayloadFromKeyVaults = ( awsRegion: region, awsSecretAccessKey: secretAccessKey, awsSessionToken: sessionToken, + runtimeProvider, }; } @@ -66,17 +91,19 @@ export const buildPayloadFromKeyVaults = ( apiKey: keyVaults.apiKey, azureApiVersion: keyVaults.apiVersion, baseURL: keyVaults.baseURL || keyVaults.endpoint, + runtimeProvider, }; } case ModelProvider.Ollama: { - return { baseURL: keyVaults.baseURL }; + return { baseURL: keyVaults.baseURL, runtimeProvider }; } case ModelProvider.Cloudflare: { return { apiKey: keyVaults.apiKey, cloudflareBaseURLOrAccountID: keyVaults.baseURLOrAccountID, + runtimeProvider, }; } @@ -87,6 +114,7 @@ export const buildPayloadFromKeyVaults = ( baseURL: keyVaults.baseURL, customHeaders: keyVaults.customHeaders, password: keyVaults.password, + runtimeProvider, username: keyVaults.username, }; } @@ -95,6 +123,7 @@ export const buildPayloadFromKeyVaults = ( return { apiKey: keyVaults.apiKey, baseURL: keyVaults.baseURL, + runtimeProvider, vertexAIRegion: keyVaults.region, }; } @@ -103,6 +132,7 @@ export const buildPayloadFromKeyVaults = ( return { apiKey: keyVaults.apiKey, baseURL: keyVaults.baseURL, + runtimeProvider, }; } } @@ -350,10 +380,16 @@ export const initModelRuntimeFromDB = async ( KeyVaultsGateKeeper.getUserKeyVaults, ); - // 2. Build ClientSecretPayload from keyVaults - const keyVaults = (providerConfig?.keyVaults || {}) as ProviderKeyVaults; - const payload = buildPayloadFromKeyVaults(provider, keyVaults); + // 2. Resolve the runtime provider for custom providers + // For custom providers, use sdkType from settings (defaults to 'openai') + const sdkType = providerConfig?.settings?.sdkType; + const runtimeProvider = resolveRuntimeProvider(provider, sdkType); - // 3. Initialize ModelRuntime with the payload + // 3. Build ClientSecretPayload from keyVaults based on runtimeProvider + // This ensures provider-specific fields (e.g., cloudflareBaseURLOrAccountID) are included + const keyVaults = (providerConfig?.keyVaults || {}) as ProviderKeyVaults; + const payload = buildPayloadFromKeyVaults(keyVaults, runtimeProvider); + + // 4. Initialize ModelRuntime with the payload return initModelRuntimeWithUserPayload(provider, payload); }; diff --git a/vitest.config.mts b/vitest.config.mts index 982be40b4e..aa092fbae7 100644 --- a/vitest.config.mts +++ b/vitest.config.mts @@ -92,6 +92,7 @@ export default defineConfig({ '**/e2e/**', ], globals: true, + reporters: ['default', 'blob'], server: { deps: { inline: ['vitest-canvas-mock', '@lobehub/ui', '@lobehub/fluent-emoji'],