
Prompt Caching

Oftentimes, you may want to reuse the KV cache across multiple inferlet invocations to save computation and latency. Pie provides built-in support for inter-inferlet KV cache reuse. This tutorial covers the preliminary concepts of how Pie virtually manages the KV cache, and how you can export and import KV cache pages between inferlets.

KV cache pages

Pie uses paged attention to manage the KV cache: the KV cache of a context is a collection of fixed-size pages (typically 8 to 16 tokens per page). Pie virtualizes these pages much like virtual memory in an operating system. Each context has its own virtual address space for KV cache pages, and the Pie runtime manages the mapping between virtual addresses and physical pages. For example, to allocate and deallocate KV cache pages:

let q = model.create_queue();
// allocate ten new KV cache pages
let pages = q.allocate_kv_pages(10);
// deallocate the KV cache pages when no longer needed
q.deallocate_kv_pages(pages);

The returned pages is a list of virtual addresses that refer to the newly allocated KV cache pages. You can pass these addresses to other APIs that take KV cache pages, such as forward. In practice, you typically do not need to allocate or deallocate KV cache pages manually, as the Pie standard library provides higher-level abstractions for context management, but it is useful to understand how the pages are managed under the hood.
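
For instance, the Context API used later in this tutorial handles paging for you. A minimal sketch, based on the calls that appear in the prefix caching example below:

let model = inferlet::get_auto_model();
let mut ctx = model.create_context();
// Filling and flushing tokens allocates KV cache pages behind the scenes;
// no explicit allocate_kv_pages/deallocate_kv_pages calls are needed.
ctx.fill("You are a helpful assistant.");
ctx.flush().await;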

Exporting/importing KV cache pages

You can "export" KV cache pages by attaching a specific id to them. Exported KV cache pages are treated as persistent objects that outlive the inferlet instance that created them. For instance,

q.export_kv_pages(pages, "my_page_id");

Other inferlets can "import" the exported KV cache pages by specifying the same id. For instance,

let imported_pages = q.import_kv_pages("my_page_id");

Creating a context from raw KV cache pages

You can create a new context from raw KV cache pages using the Context::from_imported_state constructor:

let mut ctx = Context::from_imported_state(&model, imported_pages, /* prefix tokens */, /* kv page last len */);

Note that Pie does not automatically track which portion of the KV pages is filled with valid tokens and which is not. The Context abstraction only keeps track of the length of the last, partially filled page, so you need to provide the prefix tokens and the last-page length yourself when creating a context from raw KV pages.
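
For example, suppose the imported prefix is 50 tokens long and the page size is 16 tokens (the page size here is assumed for illustration; the actual value depends on the runtime configuration). The prefix then spans four pages, and only the first two slots of the last page hold valid tokens:

let page_size = 16; // assumed for illustration; runtime dependent
let prefix_tokens: Vec<u32> = /* the 50 prefix token ids */;
let num_pages = prefix_tokens.len().div_ceil(page_size); // ceil(50 / 16) = 4
let kv_page_last_len = prefix_tokens.len() - (num_pages - 1) * page_size; // 50 - 48 = 2
let mut ctx = Context::from_imported_state(&model, imported_pages, prefix_tokens, kv_page_last_len);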

Sharing states across different inferlets

There are scenarios where you may want to persistently share state across different inferlets, which requires a different mechanism from broadcast/subscribe. Pie provides a lightweight key-value store for sharing state across inferlets. You can think of it as a simple database that is accessible by all inferlets.

The key-value store is accessed through the inferlet::store_set, inferlet::store_get, inferlet::store_delete, and inferlet::store_exists APIs. For example:

use inferlet::{Args, Result, store_get, store_set};

#[inferlet::main]
async fn main(mut args: Args) -> Result<()> {
    store_set("some key", "300").await;
    // store_get returns None if the key does not exist.
    let retrieved_value: String = store_get("some key")
        .await
        .ok_or_else(|| inferlet::anyhow!("key not found"))?;
    println!("Retrieved value: {}", retrieved_value);
    Ok(())
}
Warning

The key-value store currently does not provide separate namespaces for different users, so be careful when using it to avoid key collisions.
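
A simple mitigation in the meantime is to prefix every key with an application-specific namespace. The namespaced_key helper below is purely illustrative, not part of the inferlet API:

// Illustrative helper: prefix keys with an application name
// to reduce the risk of collisions in the shared store.
fn namespaced_key(app: &str, key: &str) -> String {
    format!("{}:{}", app, key)
}

// Stores under "pirate_bot:greeting" rather than a bare "greeting".
inferlet::store_set(&namespaced_key("pirate_bot", "greeting"), "Ahoy!").await;

The prefix caching example below follows the same idea by embedding a content hash of the system prompt in every key it touches.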

Prefix caching example

Putting it all together, here is a simple example of how you can export and import KV cache pages to reuse a system prompt's KV cache across different inferlet invocations.

// Imports: serde, serde_json, and sha2 are external crates; the exact
// paths of Context, Resource, Sampler, and stop_condition within the
// inferlet crate are assumed here.
use inferlet::{stop_condition, Args, Context, Resource, Result, Sampler};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::time::Instant;

#[derive(Serialize, Deserialize)]
struct CachedPrefixState {
    token_ids: Vec<u32>,
    kv_page_last_len: usize,
}

#[inferlet::main]
async fn main(mut args: Args) -> Result<()> {
    let system_prompt = args
        .opt_value_from_str(["-s", "--system-prompt"])?
        .unwrap_or_else(|| "You are a helpful pirate who speaks in pirate slang.".to_string());
    let prompt = args
        .opt_value_from_str(["-p", "--prompt"])?
        .unwrap_or_else(|| "What is the capital of South Korea?".to_string());
    let max_num_outputs: usize = args
        .opt_value_from_str(["-n", "--max-tokens"])?
        .unwrap_or(128);
    let invalidate_cache: bool = args.contains(["-i", "--invalidate-cache"]);

    // Key the cache on a hash of the system prompt so that different
    // prefixes never collide in the shared store.
    let mut hasher = Sha256::new();
    hasher.update(system_prompt.as_bytes());
    let prefix_hash = format!("{:x}", hasher.finalize());

    let cache_flag_key = format!("cached_v1_{}", prefix_hash);
    let cache_export_name = format!("kv_pages_v1_{}", prefix_hash);
    let cache_state_key = format!("state_v1_{}", prefix_hash);

    let start = Instant::now();
    let model = inferlet::get_auto_model();
    let mut ctx: Context;

    if invalidate_cache && inferlet::store_get(&cache_flag_key).await.as_deref() == Some("true") {
        model
            .create_queue()
            .release_exported_resources(Resource::KvPage, &cache_export_name);
        inferlet::store_set(&cache_flag_key, "false").await;
        println!("🗑️ Cache invalidated for this prefix.");
    }

    if inferlet::store_get(&cache_flag_key).await.as_deref() == Some("true") {
        println!("Cache HIT. Loading pre-computed prefix from store.");

        let imported_page_ptrs = model
            .create_queue()
            .import_resource(Resource::KvPage, &cache_export_name);
        let state_json = inferlet::store_get(&cache_state_key)
            .await
            .ok_or_else(|| inferlet::anyhow!("Cache Inconsistency: State missing"))?;
        let state: CachedPrefixState = serde_json::from_str(&state_json)?;

        ctx = Context::from_imported_state(
            &model,
            imported_page_ptrs,
            state.token_ids,
            state.kv_page_last_len,
        );
    } else {
        println!("Cache MISS. Computing and saving prefix to store.");
        let mut prefill_ctx = model.create_context();

        prefill_ctx.fill(&system_prompt);
        prefill_ctx.flush().await; // Process the tokens.

        // Save the computed KV pages and token state.
        let page_ptrs = prefill_ctx.get_kv_page_ptrs();
        let state_to_cache = CachedPrefixState {
            token_ids: prefill_ctx.get_token_ids().to_vec(),
            kv_page_last_len: prefill_ctx.get_kv_page_last_len(),
        };

        prefill_ctx
            .queue()
            .export_resource(Resource::KvPage, &page_ptrs, &cache_export_name);

        let state_json = serde_json::to_string(&state_to_cache)?;
        inferlet::store_set(&cache_state_key, &state_json).await;
        inferlet::store_set(&cache_flag_key, "true").await;

        ctx = prefill_ctx;
    }

    ctx.fill_user(&prompt);

    let stop_cond = stop_condition::max_len(max_num_outputs)
        .or(stop_condition::ends_with_any(model.eos_tokens()));

    let text = ctx.generate(Sampler::greedy(), stop_cond).await;
    println!("{}", text);
    println!("Total time: {:?}", start.elapsed());

    Ok(())
}
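
A note on the design: the cached prefix consists of three pieces that must stay consistent: the flag key, the exported KV pages, and the serialized token state. The flag is set to "true" only after both the pages and the state have been saved, so an inferlet that observes the flag can assume the other two exist. The "Cache Inconsistency" error path catches the remaining case where they have drifted apart, for example after a partial invalidation.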