Skip to content

Add MXFP8 quantized_model_init memory profiler for FSDP2 qinit analysis#3008

Draft
savitha-eng wants to merge 5 commits into
NVIDIA:mainfrom
savitha-eng:savitha/mxfp8-memory-profiler
Draft

Add MXFP8 quantized_model_init memory profiler for FSDP2 qinit analysis#3008
savitha-eng wants to merge 5 commits into
NVIDIA:mainfrom
savitha-eng:savitha/mxfp8-memory-profiler

Conversation

@savitha-eng
Copy link
Copy Markdown

@savitha-eng savitha-eng commented May 18, 2026

Summary

Standalone memory profiler script for diagnosing MXFP8 quantized_model_init memory behavior with FSDP2. Creates one or more te.TransformerLayer blocks with 8B-scale dimensions, wraps with FSDP2 fully_shard, and runs forward+backward+step iterations while recording PyTorch memory history.

Issue observed: When using quantized_model_init + FSDP2, MXFP8 quantized weight tensors from mxfp8_tensor.py:quantize_impl are never freed. FSDP2 calls .view(numel,) to flatten params, which triggers _ViewFunc dequantize fallback, and the intermediate tensors leak. With --num-layers 4, the leaked memory accumulates per layer.

Quick repro (requires 2+ GPUs)

# BF16 baseline (control — no leak)
torchrun --nproc-per-node 2 examples/pytorch/quantized_model_init/single_block_memory_profile.py --mode bare-fsdp2

# MXFP8 + qinit + FSDP2 (shows leaked tensors)
torchrun --nproc-per-node 2 examples/pytorch/quantized_model_init/single_block_memory_profile.py --mode mxfp8-fsdp2

# 4 layers (shows cross-layer accumulation)
torchrun --nproc-per-node 2 examples/pytorch/quantized_model_init/single_block_memory_profile.py --mode mxfp8-fsdp2 --num-layers 4

Snapshots saved to /tmp/single_block_snapshots/ — view at https://pytorch.org/memory_viz

Available modes

Mode Description
bare BF16 baseline, no FP8, no FSDP2
mxfp8 MXFP8 + quantized_model_init, no FSDP2
fp8-no-qinit FP8 autocast without qinit, no FSDP2
mxfp8-no-qinit MXFP8 autocast without qinit, no FSDP2
bare-fsdp2 BF16 + FSDP2 (control)
mxfp8-fsdp2 MXFP8 + qinit + FSDP2 (repro)
fp8-no-qinit-fsdp2 FP8 autocast + FSDP2, no qinit
mxfp8-no-qinit-fsdp2 MXFP8 autocast + FSDP2, no qinit

Additional flags: --model-size {8b,70b}, --num-layers N, --no-hpiv, --recipe {mxfp8,float8block,auto}

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • Add examples/pytorch/quantized_model_init/single_block_memory_profile.py — self-contained memory profiler with 8 modes for comparing BF16 vs MXFP8 vs FP8 autocast, with and without FSDP2
  • No changes to TE library code

@savitha-eng savitha-eng changed the title Add MXFP8 single-block memory profiler for FSDP2 qinit analysis Add MXFP8 quantized_model_init memory profiler for FSDP2 qinit analysis May 19, 2026
@nvMelissa
Copy link
Copy Markdown
Collaborator

Capturing priority from @savitha-eng : This PR doesn't hard block anything — it's a reproducer for a memory discrepancy we observed with MXFP8 + quantized_model_init. Specifically, we're seeing a ~13% persistent memory overhead per rank with MXFP8 + qinit + FSDP2 vs. the BF16 + FSDP2 baseline (MXFP8 without qinit shows zero overhead). For an 8B model on 2 GPUs, that translates to ~5.7 GB/rank of extra memory. We'd expect MXFP8 + qinit to perform at least as well as BF16 or MXFP8 without qinit on memory — not worse. So while it isn't blocking, the priority is medium, since the delta may be significant enough to matter at scale with larger models/tighter memory budgets. cc: @sbhavani @ptrendx

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants