CUDA

Pie has two CUDA-oriented paths:

  • cuda_native: embedded C++/CUDA driver linked into the pie binary when the installed flavor includes CUDA.
  • dev: subprocess Python driver backed by PyTorch and FlashInfer, useful for development and model/kernel prototyping.

Use cuda_native for the standalone binary path. Use dev when you need the readable Python implementation or its platform extras.
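
Both paths are selected the same way in the model config; a minimal sketch (the full option lists appear in the sections below):

[model.driver]
type = "cuda_native"   # or "dev" for the Python driver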

Install

Install a CUDA-flavored pie binary:

curl -fsSL https://pie-project.org/install.sh | PIE_FLAVOR=cuda12.8 bash

Auto-detection chooses cuda13.0 for NVIDIA driver >= 580, cuda12.8 for driver >= 525, and portable otherwise. Valid CUDA flavors are cuda12.6, cuda12.8, cuda13.0, and the matching portable-cuda* variants.
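
If auto-detection does not pick the flavor you want, pin one explicitly through the same installer, using any flavor name from the list above:

curl -fsSL https://pie-project.org/install.sh | PIE_FLAVOR=cuda13.0 bash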

For the Python dev driver:

pie driver dev install ~/.pie/venvs/dev --run
pie driver dev set venv ~/.pie/venvs/dev
pie driver dev doctor

The install recipe uses pie-driver-dev[cu128] by default. For source or manual installs, the driver wheel exposes cu126, cu128, and metal extras.
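
For a manual or source install, the same extras can be requested directly; a sketch, assuming a plain pip install of the published wheel:

pip install "pie-driver-dev[cu126]"   # or [cu128] / [metal], matching your platform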

cuda_native configuration

[model.driver]
type = "cuda_native"
device = ["cuda:0"]
tensor_parallel_size = 1
activation_dtype = "bfloat16"

[model.driver.options]
gpu_mem_utilization = 0.85
kv_page_size = 32
max_batch_tokens = 10240
max_batch_size = 512
max_num_kv_pages = 1024
swap_pool_size = 0
weight_dtype = "bfloat16"
runtime_quant = ""
ready_timeout_s = 600.0
shutdown_timeout_s = 5.0

Key | Default | Description
binary_path | "" | Accepted for older config compatibility; ignored by standalone pie, since the driver is embedded.
gpu_mem_utilization | 0.85 | Fraction of free GPU memory the driver claims.
kv_page_size | 32 | KV cache page size in tokens.
max_batch_tokens | 10240 | Cap on tokens per batch.
max_batch_size | 512 | Cap on sequences per batch.
max_num_kv_pages | 1024 | KV cache page count; KV memory scales linearly with it.
swap_pool_size | 0 | Pinned host KV-page count for swap-out. 0 disables swap.
weight_dtype | "bfloat16" | Weight precision.
runtime_quant | "" | Empty disables runtime quantization; "fp8" enables the current FP8 path where supported.
ready_timeout_s | 600.0 | Seconds to wait for driver readiness.
shutdown_timeout_s | 5.0 | Seconds to wait for graceful shutdown.
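
As an illustration, a configuration that trades more GPU memory for a larger KV cache and enables host swap might look like this (the values are examples, not tuned recommendations):

[model.driver.options]
gpu_mem_utilization = 0.9
max_num_kv_pages = 4096   # KV memory grows linearly with the page count
swap_pool_size = 256      # pinned host pages for swap-out; 0 disables swap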

dev configuration

[model.driver]
type = "dev"
device = ["cuda:0"]
tensor_parallel_size = 1
activation_dtype = "bfloat16"

[model.driver.options]
venv = "/home/me/.pie/venvs/dev"
gpu_mem_utilization = 0.8
max_batch_tokens = 10240
max_batch_size = 512
max_dist_size = 32
max_num_embeds = 128
max_num_adapters = 32
max_adapter_rank = 8
kv_page_size = 16
weight_dtype = "auto"
cpu_mem_budget_in_gb = 0

Key | Default | Description
venv / python | unset | Optional per-model interpreter override. Otherwise pie driver dev set ..., PIE_PYTHON, the active venv, and python3 are checked in order.
gpu_mem_utilization | 0.8 | Fraction of free GPU memory used as the KV budget.
max_batch_tokens | 10240 | Maximum tokens summed across all sequences in a batch.
max_batch_size | 512 | Maximum sequences per batch.
max_dist_size | 32 | Cap on Distribution { k } probe size.
max_num_embeds | 128 | Cap on embedding lookups per batch.
max_num_adapters | 32 | Adapter slot capacity.
max_adapter_rank | 8 | Maximum LoRA rank.
kv_page_size | 16 | KV cache page size in tokens.
weight_dtype | "auto" | auto, float32, float16, bfloat16, int4, int8, or float8.
cpu_mem_budget_in_gb | 0 | Pinned host pool for KV swap, in GiB. 0 disables swap.
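
As an illustration, a dev model entry that pins its own interpreter, quantizes weights to int8, and enables KV swap might look like this (the venv path is only an example):

[model.driver.options]
venv = "/home/me/.pie/venvs/dev"   # per-model interpreter override
weight_dtype = "int8"
cpu_mem_budget_in_gb = 4           # 4 GiB pinned host pool for KV swap; 0 disables swap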

Supported architectures

cuda_native covers the architectures ported to driver/cuda/src/. The dev driver has the broader Python model roster under driver/dev/src/pie_driver_dev/model/.

Family | HF model_type | Notes
Llama 3.x / Mistral-compatible | llama | Instruct and base checkpoints.
Qwen 2.x | qwen2 | Qwen2 and Qwen2.5.
Qwen 3.x | qwen3 | Includes the default Qwen/Qwen3-0.6B.
Qwen 3.5 | qwen3_5 | Supported where the selected CUDA path has the graph implemented.
Phi-3 | phi3 | Microsoft Phi-3 family.
Mixtral | mixtral | MoE path.
Gemma 2 / 3 / 4 | gemma2, gemma3_text, gemma4_text, gemma4 | Text checkpoints.
Mistral 3 | mistral3 | Ministral-class checkpoints.
OLMo 3 | olmo3 | AI2 OLMo 3.
GPT-OSS | gptoss, gpt_oss | Community GPT-OSS variants.

Run pie model list to see whether cached HuggingFace repos are compatible with your installed drivers.

Quantization

For dev, set weight_dtype under [model.driver.options]:

[model.driver.options] # type = "dev"
weight_dtype = "float8" # auto / float32 / float16 / bfloat16 / int4 / int8 / float8

For cuda_native, weight_dtype = "bfloat16" is the default, and runtime_quant = "fp8" enables the current runtime quantization path where supported, as sketched below.
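
[model.driver.options] # type = "cuda_native"
weight_dtype = "bfloat16"
runtime_quant = "fp8" # "" disables runtime quantization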