feat: support bedrock prompt cache and usage computation (#10337)

This commit is contained in:
YuTengjing
2025-11-22 13:09:07 +08:00
committed by GitHub
parent 8b63246491
commit beb9471e15
8 changed files with 280 additions and 83 deletions

View File

@@ -28,6 +28,7 @@ The project follows a well-organized monorepo structure:
### Git Workflow
- The current release branch is `next` instead of `main` until v2.0.0 is officially released
- Use rebase for git pull
- Git commit messages should be prefixed with a gitmoji
- Git branch name format: `username/feat/feature-name`

View File

@@ -14,6 +14,7 @@ read @.cursor/rules/project-structure.mdc
### Git Workflow
- The current release branch is `next` instead of `main` until v2.0.0 is officially released
- use rebase for git pull
- git commit messages should be prefixed with a gitmoji
- git branch name format example: tj/feat/feature-name

View File

@@ -148,10 +148,9 @@ export const createRouterRuntime = ({
}
/**
* TODO: 考虑添加缓存机制,避免重复创建相同配置的 runtimes
* Resolve routers configuration and validate
*/
private async createRuntimesByRouters(model?: string): Promise<RuntimeItem[]> {
// 动态获取 routers支持传入 model
private async resolveRouters(model?: string): Promise<RouterInstance[]> {
const resolvedRouters =
typeof this._routers === 'function'
? await this._routers(this._options, { model })
@@ -161,6 +160,41 @@ export const createRouterRuntime = ({
throw new Error('empty providers');
}
return resolvedRouters;
}
/**
 * Create the runtime used for inference requests (chat, generateObject, etc.).
 * Picks the first router whose model list contains `model`; when no router
 * declares it, falls back to the last resolved router.
 */
private async createRuntimeForInference(model: string): Promise<RuntimeItem> {
  const routers = await this.resolveRouters(model);

  // An empty/absent model list never matches; `includes` on a populated list decides.
  const selected =
    routers.find((router) => router.models?.includes(model) ?? false) ?? routers.at(-1)!;

  // Resolution order for the runtime class: explicit runtime on the router,
  // then the base map keyed by apiType, then the OpenAI-compatible default.
  const RuntimeClass = selected.runtime ?? baseRuntimeMap[selected.apiType] ?? LobeOpenAI;

  // Router-specific options take precedence over instance options and params.
  const mergedOptions = { ...this._params, ...this._options, ...selected.options };
  const runtime: LobeRuntimeAI = new RuntimeClass({ ...mergedOptions, id: this._id });

  return {
    id: selected.apiType,
    models: selected.models,
    runtime,
  };
}
/**
* Create all runtimes for listing models
*/
private async createRuntimes(): Promise<RuntimeItem[]> {
const resolvedRouters = await this.resolveRouters();
return resolvedRouters.map((router) => {
const providerAI = router.runtime ?? baseRuntimeMap[router.apiType] ?? LobeOpenAI;
const finalOptions = { ...this._params, ...this._options, ...router.options };
@@ -176,16 +210,8 @@ export const createRouterRuntime = ({
// Check if it can match a specific model, otherwise default to using the last runtime
async getRuntimeByModel(model: string) {
const runtimes = await this.createRuntimesByRouters(model);
for (const runtimeItem of runtimes) {
const models = runtimeItem.models || [];
if (models.includes(model)) {
return runtimeItem.runtime;
}
}
return runtimes.at(-1)!.runtime;
const runtimeItem = await this.createRuntimeForInference(model);
return runtimeItem.runtime;
}
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
@@ -222,9 +248,8 @@ export const createRouterRuntime = ({
async models() {
if (modelsOption && typeof modelsOption === 'function') {
// 延迟创建 runtimes
const runtimes = await this.createRuntimesByRouters();
// 如果是函数式配置,使用最后一个运行时的客户端来调用函数
const runtimes = await this.createRuntimes();
// If it's a functional configuration, use the last runtime's client to call the function
const lastRuntime = runtimes.at(-1)?.runtime;
if (lastRuntime && 'client' in lastRuntime) {
const modelList = await modelsOption({ client: (lastRuntime as any).client });
@@ -232,8 +257,7 @@ export const createRouterRuntime = ({
}
}
// 延迟创建 runtimes
const runtimes = await this.createRuntimesByRouters();
const runtimes = await this.createRuntimes();
return runtimes.at(-1)?.runtime.models?.();
}

View File

@@ -7,18 +7,32 @@ import {
StreamContext,
createCallbacksTransformer,
createSSEProtocolTransformer,
createTokenSpeedCalculator,
} from '../protocol';
import { createBedrockStream } from './common';
export const AWSBedrockClaudeStream = (
res: InvokeModelWithResponseStreamResponse | ReadableStream,
cb?: ChatStreamCallbacks,
options?: {
callbacks?: ChatStreamCallbacks;
inputStartAt?: number;
payload?: Parameters<typeof transformAnthropicStream>[2];
},
): ReadableStream<string> => {
const streamStack: StreamContext = { id: 'chat_' + nanoid() };
const stream = res instanceof ReadableStream ? res : createBedrockStream(res);
const transformWithPayload: typeof transformAnthropicStream = (chunk, ctx) =>
transformAnthropicStream(chunk, ctx, options?.payload);
return stream
.pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
.pipeThrough(createCallbacksTransformer(cb));
.pipeThrough(
createTokenSpeedCalculator(transformWithPayload, {
inputStartAt: options?.inputStartAt,
streamStack,
}),
)
.pipeThrough(createSSEProtocolTransformer((c) => c, streamStack))
.pipeThrough(createCallbacksTransformer(options?.callbacks));
};

View File

@@ -21,6 +21,7 @@ import { MODEL_LIST_CONFIGS, processModelList } from '../../utils/modelParse';
import { StreamingResponse } from '../../utils/response';
import { createAnthropicGenerateObject } from './generateObject';
import { handleAnthropicError } from './handleAnthropicError';
import { resolveCacheTTL } from './resolveCacheTTL';
export interface AnthropicModelCard {
created_at: string;
@@ -33,44 +34,6 @@ type anthropicTools = Anthropic.Tool | Anthropic.WebSearchTool20250305;
const modelsWithSmallContextWindow = new Set(['claude-3-opus-20240229', 'claude-3-haiku-20240307']);
const DEFAULT_BASE_URL = 'https://api.anthropic.com';
const DEFAULT_CACHE_TTL = '5m' as const;

type CacheTTL = Anthropic.Messages.CacheControlEphemeral['ttl'];

/**
 * Resolves cache TTL from Anthropic payload or request settings.
 *
 * Resolution order:
 * 1. First `cache_control.ttl` found on a system block.
 * 2. First `cache_control.ttl` found on a message content block.
 * 3. `DEFAULT_CACHE_TTL` ('5m') when `enabledContextCaching` is set on the request.
 * 4. `undefined` otherwise (no caching).
 *
 * @param requestPayload - original chat request; only `enabledContextCaching` is read
 * @param anthropicPayload - transformed Anthropic request whose blocks may carry cache_control
 * @returns the first TTL found, the default when caching is enabled, else `undefined`
 */
const resolveCacheTTL = (
  requestPayload: ChatStreamPayload,
  anthropicPayload: Anthropic.MessageCreateParams,
): CacheTTL | undefined => {
  // Check system messages for cache TTL
  if (Array.isArray(anthropicPayload.system)) {
    for (const block of anthropicPayload.system) {
      const ttl = block.cache_control?.ttl;
      if (ttl) return ttl;
    }
  }

  // Check message content blocks for cache TTL
  for (const message of anthropicPayload.messages ?? []) {
    if (!Array.isArray(message.content)) continue;
    for (const block of message.content) {
      // Narrow with `in` instead of casting the whole `&&` expression:
      // the previous `(… && …) as CacheTTL | undefined` could label the
      // literal `false` as a TTL, silencing the type checker rather than
      // satisfying it. The truthiness guard below keeps runtime behavior identical.
      const ttl =
        'cache_control' in block ? (block.cache_control?.ttl as CacheTTL | undefined) : undefined;
      if (ttl) return ttl;
    }
  }

  // Use default TTL if context caching is enabled
  if (requestPayload.enabledContextCaching) {
    return DEFAULT_CACHE_TTL;
  }

  return undefined;
};
interface AnthropicAIParams extends ClientOptions {
id?: string;

View File

@@ -0,0 +1,44 @@
import Anthropic from '@anthropic-ai/sdk';
import { ChatStreamPayload } from '../../types';
type CacheTTL = Anthropic.Messages.CacheControlEphemeral['ttl'];

const DEFAULT_CACHE_TTL = '5m' as const;

/**
 * Resolves cache TTL from Anthropic payload or request settings.
 *
 * Resolution order:
 * 1. First `cache_control.ttl` found on a system block.
 * 2. First `cache_control.ttl` found on a message content block.
 * 3. `DEFAULT_CACHE_TTL` ('5m') when `enabledContextCaching` is set on the request.
 * 4. `undefined` otherwise (no caching).
 *
 * @param requestPayload - original chat request; only `enabledContextCaching` is read
 * @param anthropicPayload - transformed Anthropic request (messages + system blocks)
 * @returns the first TTL found, the default when caching is enabled, else `undefined`
 */
export const resolveCacheTTL = (
  requestPayload: ChatStreamPayload,
  anthropicPayload: {
    messages: Anthropic.MessageCreateParams['messages'];
    system: Anthropic.MessageCreateParams['system'];
  },
): CacheTTL | undefined => {
  // Check system messages for cache TTL
  if (Array.isArray(anthropicPayload.system)) {
    for (const block of anthropicPayload.system) {
      const ttl = block.cache_control?.ttl;
      if (ttl) return ttl;
    }
  }

  // Check message content blocks for cache TTL
  for (const message of anthropicPayload.messages ?? []) {
    if (!Array.isArray(message.content)) continue;
    for (const block of message.content) {
      // Narrow with `in` instead of casting the whole `&&` expression:
      // the previous `(… && …) as CacheTTL | undefined` could label the
      // literal `false` as a TTL, silencing the type checker rather than
      // satisfying it. The truthiness guard below keeps runtime behavior identical.
      const ttl =
        'cache_control' in block ? (block.cache_control?.ttl as CacheTTL | undefined) : undefined;
      if (ttl) return ttl;
    }
  }

  // Use default TTL if context caching is enabled
  if (requestPayload.enabledContextCaching) {
    return DEFAULT_CACHE_TTL;
  }

  return undefined;
};

View File

@@ -173,7 +173,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0,
top_p: 1,
}),
@@ -211,8 +222,25 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
system: 'You are an awesome greeter',
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
system: [
{
cache_control: { type: 'ephemeral' },
text: 'You are an awesome greeter',
type: 'text',
},
],
temperature: 0,
top_p: 1,
}),
@@ -248,7 +276,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 2048,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0.25,
top_p: 1,
}),
@@ -327,7 +366,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0,
}),
contentType: 'application/json',
@@ -363,7 +413,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 2048,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0.25,
top_p: 1,
}),
@@ -418,7 +479,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0.4, // temperature / 2, top_p omitted due to conflict
}),
contentType: 'application/json',
@@ -450,7 +522,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
top_p: 0.9, // temperature omitted since not provided
}),
contentType: 'application/json',
@@ -483,7 +566,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0.4, // temperature / 2
top_p: 0.9, // both parameters allowed for older models
}),
@@ -517,7 +611,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0.3, // temperature / 2, top_p omitted due to conflict
}),
contentType: 'application/json',
@@ -550,7 +655,18 @@ describe('LobeBedrockAI', () => {
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
messages: [
{
content: [
{
cache_control: { type: 'ephemeral' },
text: 'Hello',
type: 'text',
},
],
role: 'user',
},
],
temperature: 0.35, // temperature / 2, top_p omitted due to conflict
}),
contentType: 'application/json',

View File

@@ -23,7 +23,9 @@ import {
import { AgentRuntimeErrorType } from '../../types/error';
import { AgentRuntimeError } from '../../utils/createError';
import { debugStream } from '../../utils/debugStream';
import { getModelPricing } from '../../utils/getModelPricing';
import { StreamingResponse } from '../../utils/response';
import { resolveCacheTTL } from '../anthropic/resolveCacheTTL';
/**
* A prompt constructor for HuggingFace LLama 2 chat models.
@@ -148,7 +150,16 @@ export class LobeBedrockAI implements LobeRuntimeAI {
payload: ChatStreamPayload,
options?: ChatMethodOptions,
): Promise<Response> => {
const { max_tokens, messages, model, temperature, top_p, tools } = payload;
const {
enabledContextCaching = true,
max_tokens,
messages,
model,
temperature,
top_p,
tools,
} = payload;
const inputStartAt = Date.now();
const system_message = messages.find((m) => m.role === 'system');
const user_messages = messages.filter((m) => m.role !== 'system');
@@ -159,17 +170,29 @@ export class LobeBedrockAI implements LobeRuntimeAI {
{ hasConflict, normalizeTemperature: true, preferTemperature: true },
);
const systemPrompts = !!system_message?.content
? ([
{
cache_control: enabledContextCaching ? { type: 'ephemeral' } : undefined,
text: system_message.content as string,
type: 'text',
},
] as any)
: undefined;
const anthropicPayload = {
anthropic_version: 'bedrock-2023-05-31',
max_tokens: max_tokens || 4096,
messages: await buildAnthropicMessages(user_messages, { enabledContextCaching }),
system: systemPrompts,
temperature: resolvedParams.temperature,
tools: buildAnthropicTools(tools, { enabledContextCaching }),
top_p: resolvedParams.top_p,
};
const command = new InvokeModelWithResponseStreamCommand({
accept: 'application/json',
body: JSON.stringify({
anthropic_version: 'bedrock-2023-05-31',
max_tokens: max_tokens || 4096,
messages: await buildAnthropicMessages(user_messages),
system: system_message?.content as string,
temperature: resolvedParams.temperature,
tools: buildAnthropicTools(tools),
top_p: resolvedParams.top_p,
}),
body: JSON.stringify(anthropicPayload),
contentType: 'application/json',
modelId: model,
});
@@ -186,10 +209,21 @@ export class LobeBedrockAI implements LobeRuntimeAI {
debugStream(debug).catch(console.error);
}
const pricing = await getModelPricing(payload.model, ModelProvider.Bedrock);
const cacheTTL = resolveCacheTTL({ ...payload, enabledContextCaching }, anthropicPayload);
const pricingOptions = cacheTTL ? { lookupParams: { ttl: cacheTTL } } : undefined;
// Respond with the stream
return StreamingResponse(AWSBedrockClaudeStream(prod, options?.callback), {
headers: options?.headers,
});
return StreamingResponse(
AWSBedrockClaudeStream(prod, {
callbacks: options?.callback,
inputStartAt,
payload: { model, pricing, pricingOptions, provider: ModelProvider.Bedrock },
}),
{
headers: options?.headers,
},
);
} catch (e) {
const err = e as Error & { $metadata: any };