mirror of
https://github.com/lobehub/lobehub.git
synced 2026-03-30 13:59:22 +07:00
✨ feat: support bedrock prompt cache and usage compute (#10337)
This commit is contained in:
@@ -28,6 +28,7 @@ The project follows a well-organized monorepo structure:
|
||||
|
||||
### Git Workflow
|
||||
|
||||
- The current release branch is `next` instead of `main` until v2.0.0 is officially released
|
||||
- Use rebase for git pull
|
||||
- Git commit messages should prefix with gitmoji
|
||||
- Git branch name format: `username/feat/feature-name`
|
||||
|
||||
@@ -14,6 +14,7 @@ read @.cursor/rules/project-structure.mdc
|
||||
|
||||
### Git Workflow
|
||||
|
||||
- The current release branch is `next` instead of `main` until v2.0.0 is officially released
|
||||
- use rebase for git pull
|
||||
- git commit message should prefix with gitmoji
|
||||
- git branch name format example: tj/feat/feature-name
|
||||
|
||||
@@ -148,10 +148,9 @@ export const createRouterRuntime = ({
|
||||
}
|
||||
|
||||
/**
|
||||
* TODO: 考虑添加缓存机制,避免重复创建相同配置的 runtimes
|
||||
* Resolve routers configuration and validate
|
||||
*/
|
||||
private async createRuntimesByRouters(model?: string): Promise<RuntimeItem[]> {
|
||||
// 动态获取 routers,支持传入 model
|
||||
private async resolveRouters(model?: string): Promise<RouterInstance[]> {
|
||||
const resolvedRouters =
|
||||
typeof this._routers === 'function'
|
||||
? await this._routers(this._options, { model })
|
||||
@@ -161,6 +160,41 @@ export const createRouterRuntime = ({
|
||||
throw new Error('empty providers');
|
||||
}
|
||||
|
||||
return resolvedRouters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create runtime for inference requests (chat, generateObject, etc.)
|
||||
* Finds the router that matches the model, or uses the last router as fallback
|
||||
*/
|
||||
private async createRuntimeForInference(model: string): Promise<RuntimeItem> {
|
||||
const resolvedRouters = await this.resolveRouters(model);
|
||||
|
||||
const matchedRouter =
|
||||
resolvedRouters.find((router) => {
|
||||
if (router.models && router.models.length > 0) {
|
||||
return router.models.includes(model);
|
||||
}
|
||||
return false;
|
||||
}) ?? resolvedRouters.at(-1)!;
|
||||
|
||||
const providerAI =
|
||||
matchedRouter.runtime ?? baseRuntimeMap[matchedRouter.apiType] ?? LobeOpenAI;
|
||||
const finalOptions = { ...this._params, ...this._options, ...matchedRouter.options };
|
||||
const runtime: LobeRuntimeAI = new providerAI({ ...finalOptions, id: this._id });
|
||||
|
||||
return {
|
||||
id: matchedRouter.apiType,
|
||||
models: matchedRouter.models,
|
||||
runtime,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Create all runtimes for listing models
|
||||
*/
|
||||
private async createRuntimes(): Promise<RuntimeItem[]> {
|
||||
const resolvedRouters = await this.resolveRouters();
|
||||
return resolvedRouters.map((router) => {
|
||||
const providerAI = router.runtime ?? baseRuntimeMap[router.apiType] ?? LobeOpenAI;
|
||||
const finalOptions = { ...this._params, ...this._options, ...router.options };
|
||||
@@ -176,16 +210,8 @@ export const createRouterRuntime = ({
|
||||
|
||||
// Check if it can match a specific model, otherwise default to using the last runtime
|
||||
async getRuntimeByModel(model: string) {
|
||||
const runtimes = await this.createRuntimesByRouters(model);
|
||||
|
||||
for (const runtimeItem of runtimes) {
|
||||
const models = runtimeItem.models || [];
|
||||
if (models.includes(model)) {
|
||||
return runtimeItem.runtime;
|
||||
}
|
||||
}
|
||||
|
||||
return runtimes.at(-1)!.runtime;
|
||||
const runtimeItem = await this.createRuntimeForInference(model);
|
||||
return runtimeItem.runtime;
|
||||
}
|
||||
|
||||
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
|
||||
@@ -222,9 +248,8 @@ export const createRouterRuntime = ({
|
||||
|
||||
async models() {
|
||||
if (modelsOption && typeof modelsOption === 'function') {
|
||||
// 延迟创建 runtimes
|
||||
const runtimes = await this.createRuntimesByRouters();
|
||||
// 如果是函数式配置,使用最后一个运行时的客户端来调用函数
|
||||
const runtimes = await this.createRuntimes();
|
||||
// If it's a functional configuration, use the last runtime's client to call the function
|
||||
const lastRuntime = runtimes.at(-1)?.runtime;
|
||||
if (lastRuntime && 'client' in lastRuntime) {
|
||||
const modelList = await modelsOption({ client: (lastRuntime as any).client });
|
||||
@@ -232,8 +257,7 @@ export const createRouterRuntime = ({
|
||||
}
|
||||
}
|
||||
|
||||
// 延迟创建 runtimes
|
||||
const runtimes = await this.createRuntimesByRouters();
|
||||
const runtimes = await this.createRuntimes();
|
||||
return runtimes.at(-1)?.runtime.models?.();
|
||||
}
|
||||
|
||||
|
||||
@@ -7,18 +7,32 @@ import {
|
||||
StreamContext,
|
||||
createCallbacksTransformer,
|
||||
createSSEProtocolTransformer,
|
||||
createTokenSpeedCalculator,
|
||||
} from '../protocol';
|
||||
import { createBedrockStream } from './common';
|
||||
|
||||
export const AWSBedrockClaudeStream = (
|
||||
res: InvokeModelWithResponseStreamResponse | ReadableStream,
|
||||
cb?: ChatStreamCallbacks,
|
||||
options?: {
|
||||
callbacks?: ChatStreamCallbacks;
|
||||
inputStartAt?: number;
|
||||
payload?: Parameters<typeof transformAnthropicStream>[2];
|
||||
},
|
||||
): ReadableStream<string> => {
|
||||
const streamStack: StreamContext = { id: 'chat_' + nanoid() };
|
||||
|
||||
const stream = res instanceof ReadableStream ? res : createBedrockStream(res);
|
||||
|
||||
const transformWithPayload: typeof transformAnthropicStream = (chunk, ctx) =>
|
||||
transformAnthropicStream(chunk, ctx, options?.payload);
|
||||
|
||||
return stream
|
||||
.pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
|
||||
.pipeThrough(createCallbacksTransformer(cb));
|
||||
.pipeThrough(
|
||||
createTokenSpeedCalculator(transformWithPayload, {
|
||||
inputStartAt: options?.inputStartAt,
|
||||
streamStack,
|
||||
}),
|
||||
)
|
||||
.pipeThrough(createSSEProtocolTransformer((c) => c, streamStack))
|
||||
.pipeThrough(createCallbacksTransformer(options?.callbacks));
|
||||
};
|
||||
|
||||
@@ -21,6 +21,7 @@ import { MODEL_LIST_CONFIGS, processModelList } from '../../utils/modelParse';
|
||||
import { StreamingResponse } from '../../utils/response';
|
||||
import { createAnthropicGenerateObject } from './generateObject';
|
||||
import { handleAnthropicError } from './handleAnthropicError';
|
||||
import { resolveCacheTTL } from './resolveCacheTTL';
|
||||
|
||||
export interface AnthropicModelCard {
|
||||
created_at: string;
|
||||
@@ -33,44 +34,6 @@ type anthropicTools = Anthropic.Tool | Anthropic.WebSearchTool20250305;
|
||||
const modelsWithSmallContextWindow = new Set(['claude-3-opus-20240229', 'claude-3-haiku-20240307']);
|
||||
|
||||
const DEFAULT_BASE_URL = 'https://api.anthropic.com';
|
||||
const DEFAULT_CACHE_TTL = '5m' as const;
|
||||
|
||||
type CacheTTL = Anthropic.Messages.CacheControlEphemeral['ttl'];
|
||||
|
||||
/**
|
||||
* Resolves cache TTL from Anthropic payload or request settings
|
||||
* Returns the first valid TTL found in system messages or content blocks
|
||||
*/
|
||||
const resolveCacheTTL = (
|
||||
requestPayload: ChatStreamPayload,
|
||||
anthropicPayload: Anthropic.MessageCreateParams,
|
||||
): CacheTTL | undefined => {
|
||||
// Check system messages for cache TTL
|
||||
if (Array.isArray(anthropicPayload.system)) {
|
||||
for (const block of anthropicPayload.system) {
|
||||
const ttl = block.cache_control?.ttl;
|
||||
if (ttl) return ttl;
|
||||
}
|
||||
}
|
||||
|
||||
// Check message content blocks for cache TTL
|
||||
for (const message of anthropicPayload.messages ?? []) {
|
||||
if (!Array.isArray(message.content)) continue;
|
||||
|
||||
for (const block of message.content) {
|
||||
// Message content blocks might have cache_control property
|
||||
const ttl = ('cache_control' in block && block.cache_control?.ttl) as CacheTTL | undefined;
|
||||
if (ttl) return ttl;
|
||||
}
|
||||
}
|
||||
|
||||
// Use default TTL if context caching is enabled
|
||||
if (requestPayload.enabledContextCaching) {
|
||||
return DEFAULT_CACHE_TTL;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
};
|
||||
|
||||
interface AnthropicAIParams extends ClientOptions {
|
||||
id?: string;
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import Anthropic from '@anthropic-ai/sdk';
|
||||
|
||||
import { ChatStreamPayload } from '../../types';
|
||||
|
||||
type CacheTTL = Anthropic.Messages.CacheControlEphemeral['ttl'];
|
||||
|
||||
const DEFAULT_CACHE_TTL = '5m' as const;
|
||||
|
||||
/**
|
||||
* Resolves cache TTL from Anthropic payload or request settings.
|
||||
* Returns the first valid TTL found in system messages or content blocks.
|
||||
*/
|
||||
export const resolveCacheTTL = (
|
||||
requestPayload: ChatStreamPayload,
|
||||
anthropicPayload: {
|
||||
messages: Anthropic.MessageCreateParams['messages'];
|
||||
system: Anthropic.MessageCreateParams['system'];
|
||||
},
|
||||
): CacheTTL | undefined => {
|
||||
// Check system messages for cache TTL
|
||||
if (Array.isArray(anthropicPayload.system)) {
|
||||
for (const block of anthropicPayload.system) {
|
||||
const ttl = block.cache_control?.ttl;
|
||||
if (ttl) return ttl;
|
||||
}
|
||||
}
|
||||
|
||||
// Check message content blocks for cache TTL
|
||||
for (const message of anthropicPayload.messages ?? []) {
|
||||
if (!Array.isArray(message.content)) continue;
|
||||
|
||||
for (const block of message.content) {
|
||||
const ttl = ('cache_control' in block && block.cache_control?.ttl) as CacheTTL | undefined;
|
||||
if (ttl) return ttl;
|
||||
}
|
||||
}
|
||||
|
||||
// Use default TTL if context caching is enabled
|
||||
if (requestPayload.enabledContextCaching) {
|
||||
return DEFAULT_CACHE_TTL;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
};
|
||||
@@ -173,7 +173,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
top_p: 1,
|
||||
}),
|
||||
@@ -211,8 +222,25 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
system: 'You are an awesome greeter',
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
system: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'You are an awesome greeter',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
top_p: 1,
|
||||
}),
|
||||
@@ -248,7 +276,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 2048,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0.25,
|
||||
top_p: 1,
|
||||
}),
|
||||
@@ -327,7 +366,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
}),
|
||||
contentType: 'application/json',
|
||||
@@ -363,7 +413,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 2048,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0.25,
|
||||
top_p: 1,
|
||||
}),
|
||||
@@ -418,7 +479,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0.4, // temperature / 2, top_p omitted due to conflict
|
||||
}),
|
||||
contentType: 'application/json',
|
||||
@@ -450,7 +522,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
top_p: 0.9, // temperature omitted since not provided
|
||||
}),
|
||||
contentType: 'application/json',
|
||||
@@ -483,7 +566,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0.4, // temperature / 2
|
||||
top_p: 0.9, // both parameters allowed for older models
|
||||
}),
|
||||
@@ -517,7 +611,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0.3, // temperature / 2, top_p omitted due to conflict
|
||||
}),
|
||||
contentType: 'application/json',
|
||||
@@ -550,7 +655,18 @@ describe('LobeBedrockAI', () => {
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: 4096,
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
messages: [
|
||||
{
|
||||
content: [
|
||||
{
|
||||
cache_control: { type: 'ephemeral' },
|
||||
text: 'Hello',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
temperature: 0.35, // temperature / 2, top_p omitted due to conflict
|
||||
}),
|
||||
contentType: 'application/json',
|
||||
|
||||
@@ -23,7 +23,9 @@ import {
|
||||
import { AgentRuntimeErrorType } from '../../types/error';
|
||||
import { AgentRuntimeError } from '../../utils/createError';
|
||||
import { debugStream } from '../../utils/debugStream';
|
||||
import { getModelPricing } from '../../utils/getModelPricing';
|
||||
import { StreamingResponse } from '../../utils/response';
|
||||
import { resolveCacheTTL } from '../anthropic/resolveCacheTTL';
|
||||
|
||||
/**
|
||||
* A prompt constructor for HuggingFace LLama 2 chat models.
|
||||
@@ -148,7 +150,16 @@ export class LobeBedrockAI implements LobeRuntimeAI {
|
||||
payload: ChatStreamPayload,
|
||||
options?: ChatMethodOptions,
|
||||
): Promise<Response> => {
|
||||
const { max_tokens, messages, model, temperature, top_p, tools } = payload;
|
||||
const {
|
||||
enabledContextCaching = true,
|
||||
max_tokens,
|
||||
messages,
|
||||
model,
|
||||
temperature,
|
||||
top_p,
|
||||
tools,
|
||||
} = payload;
|
||||
const inputStartAt = Date.now();
|
||||
const system_message = messages.find((m) => m.role === 'system');
|
||||
const user_messages = messages.filter((m) => m.role !== 'system');
|
||||
|
||||
@@ -159,17 +170,29 @@ export class LobeBedrockAI implements LobeRuntimeAI {
|
||||
{ hasConflict, normalizeTemperature: true, preferTemperature: true },
|
||||
);
|
||||
|
||||
const systemPrompts = !!system_message?.content
|
||||
? ([
|
||||
{
|
||||
cache_control: enabledContextCaching ? { type: 'ephemeral' } : undefined,
|
||||
text: system_message.content as string,
|
||||
type: 'text',
|
||||
},
|
||||
] as any)
|
||||
: undefined;
|
||||
|
||||
const anthropicPayload = {
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: max_tokens || 4096,
|
||||
messages: await buildAnthropicMessages(user_messages, { enabledContextCaching }),
|
||||
system: systemPrompts,
|
||||
temperature: resolvedParams.temperature,
|
||||
tools: buildAnthropicTools(tools, { enabledContextCaching }),
|
||||
top_p: resolvedParams.top_p,
|
||||
};
|
||||
|
||||
const command = new InvokeModelWithResponseStreamCommand({
|
||||
accept: 'application/json',
|
||||
body: JSON.stringify({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: max_tokens || 4096,
|
||||
messages: await buildAnthropicMessages(user_messages),
|
||||
system: system_message?.content as string,
|
||||
temperature: resolvedParams.temperature,
|
||||
tools: buildAnthropicTools(tools),
|
||||
top_p: resolvedParams.top_p,
|
||||
}),
|
||||
body: JSON.stringify(anthropicPayload),
|
||||
contentType: 'application/json',
|
||||
modelId: model,
|
||||
});
|
||||
@@ -186,10 +209,21 @@ export class LobeBedrockAI implements LobeRuntimeAI {
|
||||
debugStream(debug).catch(console.error);
|
||||
}
|
||||
|
||||
const pricing = await getModelPricing(payload.model, ModelProvider.Bedrock);
|
||||
const cacheTTL = resolveCacheTTL({ ...payload, enabledContextCaching }, anthropicPayload);
|
||||
const pricingOptions = cacheTTL ? { lookupParams: { ttl: cacheTTL } } : undefined;
|
||||
|
||||
// Respond with the stream
|
||||
return StreamingResponse(AWSBedrockClaudeStream(prod, options?.callback), {
|
||||
headers: options?.headers,
|
||||
});
|
||||
return StreamingResponse(
|
||||
AWSBedrockClaudeStream(prod, {
|
||||
callbacks: options?.callback,
|
||||
inputStartAt,
|
||||
payload: { model, pricing, pricingOptions, provider: ModelProvider.Bedrock },
|
||||
}),
|
||||
{
|
||||
headers: options?.headers,
|
||||
},
|
||||
);
|
||||
} catch (e) {
|
||||
const err = e as Error & { $metadata: any };
|
||||
|
||||
|
||||
Reference in New Issue
Block a user