Skip to content

Commit

Permalink
feat(llm): Add prompt caching for Anthropic Claude models
Browse files Browse the repository at this point in the history
Add prompt caching parameters for all Claude-3 series models, supporting tagged text
caching to improve response speed. Each model can cache up to 4 text blocks.
  • Loading branch information
赵旭阳 committed Dec 27, 2024
1 parent 55c327f commit 1cecc25
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: prompt_caching
label:
en_US: Prompt Caching
zh_Hans: 提示词缓存
type: boolean
required: false
help:
zh_Hans: 缓存使用<prompt-cache></prompt-cache> 包裹的提示词(最多 4 组,每组 1024+ 个 token)
en_US: <prompt-cache>prompts</prompt-cache> 4- blocks, 1024+ tokens
pricing:
input: '3.00'
output: '15.00'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: prompt_caching
label:
en_US: Prompt Caching
zh_Hans: 提示词缓存
type: boolean
required: false
help:
zh_Hans: 缓存使用<prompt-cache></prompt-cache> 包裹的提示词(最多 4 组,每组 1024+ 个 token)
en_US: <prompt-cache>prompts</prompt-cache> 4- blocks, 1024+ tokens
pricing:
input: '3.00'
output: '15.00'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ parameter_rules:
max: 4096
- name: response_format
use_template: response_format
- name: prompt_caching
label:
en_US: Prompt Caching
zh_Hans: 提示词缓存
type: boolean
required: false
help:
zh_Hans: 缓存使用<prompt-cache></prompt-cache> 包裹的提示词(最多 4 组,每组 1024+ 个 token)
en_US: <prompt-cache>prompts</prompt-cache> 4- blocks, 1024+ tokens
pricing:
input: '0.25'
output: '1.25'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ parameter_rules:
max: 4096
- name: response_format
use_template: response_format
- name: prompt_caching
label:
en_US: Prompt Caching
zh_Hans: 提示词缓存
type: boolean
required: false
help:
zh_Hans: 缓存使用<prompt-cache></prompt-cache> 包裹的提示词(最多 4 组,每组 1024+ 个 token)
en_US: <prompt-cache>prompts</prompt-cache> 4- blocks, 1024+ tokens
pricing:
input: '15.00'
output: '75.00'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ parameter_rules:
max: 4096
- name: response_format
use_template: response_format
- name: prompt_caching
label:
en_US: Prompt Caching
zh_Hans: 提示词缓存
type: boolean
required: false
help:
zh_Hans: 缓存使用<prompt-cache></prompt-cache> 包裹的提示词(最多 4 组,每组 1024+ 个 token)
en_US: <prompt-cache>prompts</prompt-cache> 4- blocks, 1024+ tokens
pricing:
input: '3.00'
output: '15.00'
Expand Down
35 changes: 31 additions & 4 deletions api/core/model_runtime/model_providers/anthropic/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import json
from collections.abc import Generator, Sequence
import re
from collections.abc import Generator, Iterable, Sequence
from typing import Optional, Union, cast

import anthropic
Expand Down Expand Up @@ -127,19 +128,31 @@ def _chat_generate(
extra_model_kwargs["system"] = system

# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
beta_flags = []
if model == "claude-3-5-sonnet-20240620":
if model_parameters.get("max_tokens", 0) > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
beta_flags.append("max-tokens-3-5-sonnet-2024-07-15")

if any(
isinstance(content, DocumentPromptMessageContent)
for prompt_message in prompt_messages
if isinstance(prompt_message.content, list)
for content in prompt_message.content
):
extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
beta_flags.append("pdfs-2024-09-25")

if (
any(s in model for s in ["claude-3-5-sonnet", "claude-3-haiku", "claude-3-opus"])
and model_parameters.get("prompt_caching") is True
):
# remove prompt_caching parameter from model_parameters
model_parameters.pop("prompt_caching")
# append prompt-caching-2024-07-31
beta_flags.append("prompt-caching-2024-07-31")
extra_model_kwargs["system"] = self.parse_prompt_with_ephemeral_tags(system)
extra_headers = {}
if beta_flags:
extra_headers["anthropic-beta"] = ",".join(beta_flags)
if tools:
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create(
Expand Down Expand Up @@ -652,3 +665,17 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
anthropic.APIError,
],
}

def parse_prompt_with_ephemeral_tags(self, system: str) -> Iterable[ToolsBetaMessage]:
parts = re.split(r"(<prompt-cache>.*?</prompt-cache>)", system, flags=re.DOTALL)

result: list[ToolsBetaMessage] = []
for part in parts:
if part.strip(): # ignore white
if part.startswith("<prompt-cache>") and part.endswith("</prompt-cache>"):
text = part[14:-15].strip()
result.append({"text": text, "type": "text", "cache_control": {"type": "ephemeral"}})
else:
result.append({"text": part.strip(), "type": "text"})

return result

0 comments on commit 1cecc25

Please sign in to comment.