Training Memory Calculator

VRAM breakdown for training or inference, with optimizer state, activation memory, KV cache, and ZeRO/FSDP sharding. Formulas cited inline below.

Inputs: model, precision, batch size, sequence length, optimizer, activation recompute, and ZeRO / FSDP stage.

Example per-GPU readout (8.0B parameters, 32 layers; batch size 1; sequence length 4096 tokens):

  • Parameters (bf16): 16.1 GB
  • Gradients: 16.1 GB
  • Optimizer state (Adam, 12N): 96.4 GB
  • Activations (selective recompute): 18.3 GB
  • Framework overhead (~10%): 14.7 GB
  • Total per GPU: 161.4 GB

Fits-on check against: A100 40GB, A100 80GB, H100 80GB, H200 141GB, B200 192GB, MI300X 192GB.
Formulas & references

Parameters: N × bytes/param. bf16/fp16 = 2 B, fp8 = 1 B, fp4 = 0.5 B.

Gradients (training): N × bytes/param, same dtype as forward params in mixed precision.
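The following is a minimal Python sketch of the parameter and gradient formulas. It assumes decimal gigabytes (1 GB = 10⁹ bytes) and a parameter count of N = 8.03 × 10⁹, which is consistent with the 16.1 GB bf16 parameter figure shown above. The function name `weights_and_grads_gb` is hypothetical and is not part of the calculator.

```python
# Bytes per parameter for the precisions listed above.
BYTES_PER_PARAM = {"bf16": 2.0, "fp16": 2.0, "fp8": 1.0, "fp4": 0.5}

GB = 1e9  # decimal gigabytes


def weights_and_grads_gb(n_params: float, dtype: str = "bf16",
                         training: bool = True) -> tuple[float, float]:
    """Return (parameter GB, gradient GB) before any sharding.

    Gradients are assumed to share the parameters' dtype, as in mixed
    precision; inference keeps no gradients.
    """
    params = n_params * BYTES_PER_PARAM[dtype] / GB
    grads = params if training else 0.0
    return params, grads


# Assumed N = 8.03e9 (an 8B-class model).
p, g = weights_and_grads_gb(8.03e9, "bf16")
print(f"params {p:.1f} GB, grads {g:.1f} GB")  # params 16.1 GB, grads 16.1 GB
```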

Optimizer state per param: Adam mixed-precision = 12 B (fp32 master 4 B + m 4 B + v 4 B). Adam-8bit ≈ 6 B. SGD-momentum ≈ 8 B. Source: EleutherAI Transformer Math 101.
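The optimizer-state rule reduces to a bytes-per-parameter lookup. The sketch below assumes N = 8.03 × 10⁹ and decimal GB. The dictionary keys and function name are illustrative labels, not the calculator's own identifiers.

```python
# Optimizer-state bytes per parameter, from the figures above.
# Keys are hypothetical labels chosen for this sketch.
OPTIMIZER_BYTES_PER_PARAM = {
    "adam": 12.0,         # fp32 master (4 B) + m (4 B) + v (4 B)
    "adam_8bit": 6.0,     # approximate
    "sgd_momentum": 8.0,  # approximate
}


def optimizer_state_gb(n_params: float, optimizer: str = "adam") -> float:
    """Unsharded optimizer-state memory in decimal GB."""
    return n_params * OPTIMIZER_BYTES_PER_PARAM[optimizer] / 1e9


for name in OPTIMIZER_BYTES_PER_PARAM:
    print(f"{name}: {optimizer_state_gb(8.03e9, name):.1f} GB")
# prints 96.4, 48.2 and 64.2 GB respectively
```

For an 8B-class model, the Adam state alone is roughly six times the bf16 weights. That ratio is why it dominates the example readout.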

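Activations per layer (training, no recompute): s·b·h·(34 + 5·a·s/h) bytes at bf16, where s is sequence length, b is micro-batch size, h is hidden size, and a is the number of attention heads. Selective recompute drops the 5·a·s/h attention term, leaving 34·s·b·h. Full recompute keeps only ~2·s·b·h per layer (the layer's input). Source: Korthikanti et al. 2022.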
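The sketch below implements the three recompute modes of the Korthikanti et al. estimate. The model shape (hidden size 4096, 32 attention heads, 32 layers) is an assumption for an 8B-class model and is not stated on this page. With batch size 1 and sequence length 4096, it reproduces the 18.3 GB selective-recompute figure shown above.

```python
def activation_bytes_per_layer(s: int, b: int, h: int, a: int,
                               recompute: str = "none") -> float:
    """Korthikanti et al. 2022 estimate with 16-bit activations."""
    sbh = s * b * h
    if recompute == "none":
        return sbh * (34 + 5 * a * s / h)
    if recompute == "selective":
        return sbh * 34  # the 5·a·s/h attention term is recomputed, not stored
    if recompute == "full":
        return sbh * 2   # only each layer's input is kept
    raise ValueError(f"unknown recompute mode: {recompute!r}")


# Hypothetical 8B-class shape: seq 4096, batch 1, hidden 4096, 32 heads, 32 layers.
S, B, H, A, LAYERS = 4096, 1, 4096, 32, 32
for mode in ("none", "selective", "full"):
    total = LAYERS * activation_bytes_per_layer(S, B, H, A, mode)
    print(f"{mode}: {total / 1e9:.1f} GB")
# prints 104.2, 18.3 and 1.1 GB respectively
```

At s = 4096 the attention term (5·a·s/h = 160) is several times larger than the constant 34. This is why selective recompute removes most of the activation cost.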

KV cache (inference, GQA): 2 × layers × kv_heads × head_dim × seq × batch × bytes/elem. The factor 2 covers both K and V; kv_heads reflects grouped-query attention sharing.
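The KV-cache formula translates directly into code. The configuration used here is hypothetical: 32 layers, 8 KV heads versus 32 for full multi-head attention, head_dim 128, and bf16 (2 bytes per element). It shows how grouped-query attention shrinks the cache in proportion to kv_heads.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, seq: int,
                   batch: int, bytes_per_elem: float = 2) -> float:
    # Factor 2: one K and one V tensor are cached per layer.
    return 2 * layers * kv_heads * head_dim * seq * batch * bytes_per_elem


# Hypothetical config: 32 layers, head_dim 128, seq 4096, batch 1, bf16.
gqa = kv_cache_bytes(layers=32, kv_heads=8, head_dim=128, seq=4096, batch=1)
mha = kv_cache_bytes(layers=32, kv_heads=32, head_dim=128, seq=4096, batch=1)
print(f"GQA: {gqa / 1e9:.2f} GB, MHA: {mha / 1e9:.2f} GB")  # GQA: 0.54 GB, MHA: 2.15 GB
```

The cache grows linearly in both seq and batch. For long-context, high-concurrency serving, it can therefore overtake the weights.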

ZeRO / FSDP sharding: stage 1 shards optimizer state, stage 2 also shards gradients, and stage 3 also shards parameters. Each sharded component is divided by the data-parallel world size. Source: Rajbhandari et al. 2019 (ZeRO).
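Here is a sketch of how the sharding stages redistribute the static state. It assumes stage 0 means no sharding (plain data parallelism) and uses a hypothetical world size of 8. The inputs are the unsharded component sizes from the example (16.06, 16.06 and 96.36 GB). Activations are not sharded by these stages, so they are left out.

```python
def shard_per_gpu(params_gb: float, grads_gb: float, optim_gb: float,
                  stage: int, world_size: int) -> tuple[float, float, float]:
    """Per-GPU (params, grads, optimizer) GB under ZeRO/FSDP stage 0-3."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    optim = optim_gb / world_size if stage >= 1 else optim_gb
    grads = grads_gb / world_size if stage >= 2 else grads_gb
    params = params_gb / world_size if stage >= 3 else params_gb
    return params, grads, optim


for stage in range(4):
    p, g, o = shard_per_gpu(16.06, 16.06, 96.36, stage=stage, world_size=8)
    print(f"stage {stage}: {p + g + o:.1f} GB")
# prints 128.5, 44.2, 30.1 and 16.1 GB for stages 0-3
```

Stage 1 delivers most of the saving here because the optimizer state is by far the largest static component.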

Framework overhead: a flat 10% accounts for the PyTorch caching allocator, NCCL buffers, and kernel workspaces. Real-world overhead typically ranges 5–20% depending on framework, fragmentation, and microbatching.
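The flat overhead is applied to the sum of the other components. This sketch rebuilds those components from the formulas above. It assumes N = 8.03 × 10⁹ and the hypothetical 32-layer, hidden-size-4096 shape, with no sharding and selective recompute. With those inputs it reproduces the 161.4 GB total shown above, and it then sweeps the 5–20% range.

```python
def total_with_overhead(component_bytes: dict[str, float],
                        overhead_frac: float = 0.10) -> tuple[float, float, float]:
    """Return (subtotal, overhead, total) in bytes."""
    subtotal = sum(component_bytes.values())
    overhead = subtotal * overhead_frac
    return subtotal, overhead, subtotal + overhead


N = 8.03e9  # assumed parameter count
components = {
    "parameters": N * 2,                   # bf16
    "gradients": N * 2,                    # bf16
    "optimizer": N * 12,                   # Adam, mixed precision
    "activations": 32 * 34 * 4096 * 4096,  # selective recompute, 32 layers, b = 1
}
for frac in (0.05, 0.10, 0.20):
    _, overhead, total = total_with_overhead(components, frac)
    print(f"{frac:.0%} overhead: {overhead / 1e9:.1f} GB, total {total / 1e9:.1f} GB")
# prints totals of 154.1, 161.4 and 176.1 GB
```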

⚠ This is a back-of-envelope estimate. Real consumption varies with the attention implementation (FlashAttention can substantially reduce activation memory), pipeline parallelism, gradient accumulation, and MoE routing. Verify with your training stack before committing to hardware.

What it does

Computes the per-GPU memory footprint for training or inference of a transformer model, broken down by component so you can spot the dominant cost and the right knob to turn.

Why it’s useful

Existing online VRAM calculators are usually inference-only or miss the training pieces that actually matter — optimizer state (12N for Adam in mixed precision), activation memory under different recompute strategies, and how sharding stages redistribute the cost. This one cites every formula inline so engineers can verify the math themselves.

Formulas, briefly

  • Parameters: N × bytes/param
  • Gradients (training): N × bytes/param
  • Adam optimizer: 12N bytes (fp32 master + m + v). Adam-8bit: ~6N. SGD-momentum: ~8N.
  • Activations per layer (Korthikanti 2022): s·b·h·(34 + 5·a·s/h). Selective recompute drops the attention term; full recompute keeps ~2sbh per layer.
  • KV cache (inference, GQA): 2 × layers × kv_heads × head_dim × seq × batch × bytes/elem
  • ZeRO/FSDP: stage 1 shards optimizer state; stage 2 also gradients; stage 3 also parameters — each by world size.
  • Framework overhead: flat 10% for PyTorch caching allocator, NCCL buffers, kernel workspaces. Range in practice: 5–20%.

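References

  • EleutherAI. "Transformer Math 101." EleutherAI Blog, 2023.
  • Korthikanti et al. "Reducing Activation Recomputation in Large Transformer Models." 2022.
  • Rajbhandari et al. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." 2019.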

Limitations

  • Assumes a vanilla transformer with bf16 activations. FlashAttention, which avoids materializing the s×s attention matrices and can substantially reduce activation memory, is not modeled here.
  • MoE models aren’t accounted for (expert-parallelism and routed activations have different memory profiles).
  • Pipeline parallelism and tensor parallelism aren’t modeled — only ZeRO-style data-parallel sharding.
  • Inference activations are approximated as a small working set; real engines (vLLM, TGI) use paged attention, which has different memory characteristics.