wan: add dual-GPU model-parallel path for Wan 2.x LoRA training (depends on #1435) #1436

genno-whittlery wants to merge 2 commits into
Conversation
Off by default -- no behavior change unless WAN_DUAL_GPU=true is set.

Wan 2.2 14B variants (I2V-A14B, T2V-A14B, S2V-14B, etc.) are ~28 GB in bf16 -- the weights fit on one 32 GB consumer card with fp8 quant, but video training activations at 480x832x49 frames + gradient checkpointing routinely push the actual step over 32 GB even on a 14B model. Splitting the transformer blocks across two GPUs gives training-step headroom that single-GPU users can't otherwise reach without dropping resolution or frame count.

What changed:

- examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py (new): ~150 LOC helper. Splits WanModel.blocks at the midpoint across cuda:0/cuda:1. Registers forward_pre_hook on every cuda:1 block (not just the boundary -- Wan's forward passes loop-level constants context / t_mod / freqs positionally to each iteration, so a boundary-only hook would leave subsequent blocks receiving cuda:0 tensors). Bridges activations back to cuda:0 at the head module. Also explicitly moves WanModel.freqs (a tuple of plain CPU tensors, not registered buffers) so .to(device) doesn't miss them.
- examples/wanvideo/model_training/train.py: forces CPU model load when WAN_DUAL_GPU=true (so the bf16 transformer doesn't pre-allocate on cuda:0 before split), runs torchao Float8WeightOnlyConfig quantize_ with the same LoRA-skip filter used by the FLUX.2 port (skips lora_A/lora_B Linear submodules -- otherwise their requires_grad is stripped and backward fails), then calls enable_wan_dual_gpu(model.pipe.dit) after PEFT has injected LoRA so block.to(device) carries LoRA params with their base layers. Also sets FLUX2_DUAL_GPU=true after distribute so the existing runner.py branch from PR modelscope#1434 catches the device_placement=[False, ...] case in accelerator.prepare without needing a parallel WAN_DUAL_GPU branch there.

Depends on modelscope#1435 (patchify fix). The current main has a broken WanModel.patchify that returns the wrong shape and arity; Wan training fails immediately at the first forward call regardless of dual-GPU. Once modelscope#1435 lands, both single-GPU and dual-GPU Wan training paths work.

Validated locally on 2x RTX 5090 with a synthetic 8-layer WanModel (same architecture shape as real Wan 2.2, miniaturized to fit a quick smoke test): forward + backward complete across the cross-device split, output round-trips to the original (B, C, T, H, W) shape, LoRA gradients land on both cuda:0 and cuda:1 (proving cross-device autograd).

Same patch shape as the validated FLUX.2 port in PR modelscope#1434 from this account. Both share the runner.py model-parallel branch.
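For orientation, a minimal sketch of the split-and-hook mechanism described above. It reuses the names the description mentions (`enable_wan_dual_gpu`, `_make_device_bridge_hook`, `WAN_DUAL_GPU_SPLIT_AT`), but the body is an illustration of the approach under those assumptions, not the exact helper in the PR (e.g., the device the `freqs` tuple is moved to is assumed here):

```python
import os
import torch
import torch.nn as nn


def _make_device_bridge_hook(target: torch.device):
    """Move every incoming tensor (positional and keyword) onto `target` before the module runs."""
    def hook(module, args, kwargs):
        args = tuple(a.to(target) if torch.is_tensor(a) else a for a in args)
        kwargs = {k: (v.to(target) if torch.is_tensor(v) else v) for k, v in kwargs.items()}
        return args, kwargs
    return hook


def enable_wan_dual_gpu(dit: nn.Module, split_at=None):
    cuda0, cuda1 = torch.device("cuda:0"), torch.device("cuda:1")
    n_blocks = len(dit.blocks)
    if split_at is None:
        split_at = int(os.environ.get("WAN_DUAL_GPU_SPLIT_AT", n_blocks // 2))

    # Embeddings, the first half of the blocks, and the head stay on cuda:0.
    dit.patch_embedding.to(cuda0)
    dit.text_embedding.to(cuda0)
    dit.time_embedding.to(cuda0)
    dit.time_projection.to(cuda0)
    dit.head.to(cuda0)
    for name in ("img_emb", "ref_conv", "control_adapter"):  # optional variant-specific modules
        sub = getattr(dit, name, None)
        if sub is not None:
            sub.to(cuda0)
    for block in dit.blocks[:split_at]:
        block.to(cuda0)

    # Second half goes to cuda:1. Every cuda:1 block gets a bridge hook because
    # WanModel.forward passes context / t_mod / freqs positionally to each iteration,
    # so a boundary-only hook would leave later blocks receiving cuda:0 tensors.
    for block in dit.blocks[split_at:]:
        block.to(cuda1)
        block.register_forward_pre_hook(_make_device_bridge_hook(cuda1), with_kwargs=True)

    # Bring activations back to cuda:0 for the head + unpatchify.
    dit.head.register_forward_pre_hook(_make_device_bridge_hook(cuda0), with_kwargs=True)

    # freqs is a tuple of plain CPU tensors, not registered buffers, so .to() misses it.
    if hasattr(dit, "freqs") and isinstance(dit.freqs, (tuple, list)):
        dit.freqs = tuple(f.to(cuda0) for f in dit.freqs)
```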
Code Review
This pull request implements dual-GPU model parallelism for Wan video DiT training, allowing the 14B model to be trained on consumer GPUs by splitting it across two devices. It introduces a new helper module for device distribution and updates the training script to perform CPU-side FP8 quantization before moving model blocks. Reviewers suggested extending the device placement logic and activation hooks to support additional Wan variants like WanToDance. Additionally, it was recommended to use a more generic environment variable name instead of FLUX2_DUAL_GPU to signal dual-GPU mode in the training runner.
| if hasattr(dit, "img_emb") and dit.img_emb is not None: | ||
| dit.img_emb.to(cuda0) | ||
| if hasattr(dit, "ref_conv") and dit.ref_conv is not None: | ||
| dit.ref_conv.to(cuda0) | ||
| if hasattr(dit, "control_adapter") and dit.control_adapter is not None: | ||
| dit.control_adapter.to(cuda0) |
Several optional modules used in specific Wan variants (such as WanToDance) are not being moved to cuda:0. This will lead to device mismatch errors during training if these variants are used in dual-GPU mode.
| if hasattr(dit, "img_emb") and dit.img_emb is not None: | |
| dit.img_emb.to(cuda0) | |
| if hasattr(dit, "ref_conv") and dit.ref_conv is not None: | |
| dit.ref_conv.to(cuda0) | |
| if hasattr(dit, "control_adapter") and dit.control_adapter is not None: | |
| dit.control_adapter.to(cuda0) | |
| if hasattr(dit, "img_emb") and dit.img_emb is not None: | |
| dit.img_emb.to(cuda0) | |
| if hasattr(dit, "ref_conv") and dit.ref_conv is not None: | |
| dit.ref_conv.to(cuda0) | |
| if hasattr(dit, "control_adapter") and dit.control_adapter is not None: | |
| dit.control_adapter.to(cuda0) | |
| # Move WanToDance specific modules | |
| for attr in ["img_emb_refimage", "img_emb_refface", "music_projection", "music_encoder", "patch_embedding_global", "head_global"]: | |
| if hasattr(dit, attr) and getattr(dit, attr) is not None: | |
| getattr(dit, attr).to(cuda0) |
```python
dit.head.register_forward_pre_hook(
    _make_device_bridge_hook(cuda0), with_kwargs=True
)
```
The activation bridge hook is only registered on dit.head. However, some Wan variants (like WanToDance) use dit.head_global instead. If head_global is used, the activations coming from the last block on cuda:1 will not be moved back to cuda:0, causing a device mismatch error.
Suggested change:

```python
# Bridge activations back to cuda:0 for the head + unpatchify.
for head_attr in ["head", "head_global"]:
    if hasattr(dit, head_attr) and getattr(dit, head_attr) is not None:
        getattr(dit, head_attr).register_forward_pre_hook(
            _make_device_bridge_hook(cuda0), with_kwargs=True
        )
```
```python
# Signal launch_training_task to skip its own model.to() move
# (runner.py keys off FLUX2_DUAL_GPU for that branch; reuse to
# avoid duplicating the device-placement logic).
os.environ["FLUX2_DUAL_GPU"] = "true"
```
Setting the FLUX2_DUAL_GPU environment variable to control behavior for a Wan model is confusing and creates a hidden, model-specific dependency in the training runner. It would be better to use a more generic environment variable name, such as DIFFSYNTH_DUAL_GPU, and update the logic in runner.py to recognize this shared flag for any model-parallel training path.
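A sketch of what the suggested generic gate could look like. `DIFFSYNTH_DUAL_GPU` is the name proposed in the comment above; the `accelerator.prepare` fragment in the comments is illustrative of the described `device_placement=[False, ...]` pattern, not the actual runner.py code:

```python
import os


def is_model_parallel_enabled() -> bool:
    """One generic flag for any manually pre-split (model-parallel) training path."""
    return os.environ.get("DIFFSYNTH_DUAL_GPU", "").lower() in ("1", "true", "yes")


# Illustrative use inside the training runner (names and surrounding code assumed):
#
#     if is_model_parallel_enabled():
#         # The training script already placed the DiT halves on cuda:0 / cuda:1;
#         # skip the global model.to(accelerator.device) and tell accelerate
#         # not to re-place the model.
#         model, optimizer, dataloader = accelerator.prepare(
#             model, optimizer, dataloader, device_placement=[False, True, True]
#         )
#     else:
#         model = model.to(accelerator.device)
#         model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```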
Force-pushed from ec4e3dd to 61a4c44
Per @gemini-code-assist review on modelscope#1436: use a model-neutral env var name. This change:

- WAN_DUAL_GPU -> DIFFSYNTH_DUAL_GPU
- WAN_DUAL_GPU_SPLIT_AT -> DIFFSYNTH_DUAL_GPU_SPLIT_AT

The same generic name will be adopted in modelscope#1434 (flux2 dual-GPU) so a single env var controls dual-GPU mode for both Wan and FLUX.2 paths. The helper function name (is_dual_gpu_enabled) and helper module file name (wan_dual_gpu_diffsynth.py) are kept Wan-specific to make the import call sites at the entry point clearer.

Also dropped the now-redundant os.environ["FLUX2_DUAL_GPU"]="true" re-broadcast in train.py; the launcher-set DIFFSYNTH_DUAL_GPU is inherited by the same process and read by both is_dual_gpu_enabled() and runner.py's gate (after modelscope#1434 lands its matching rename). Added a comment block making the implicit runtime dependency on modelscope#1434's runner.py change explicit.
Per @gemini-code-assist review on modelscope#1436: model-neutral env var name so a single signal controls dual-GPU mode for both FLUX.2 and Wan paths. This change:

- FLUX2_DUAL_GPU -> DIFFSYNTH_DUAL_GPU
- FLUX2_DUAL_GPU_SPLIT_AT -> DIFFSYNTH_DUAL_GPU_SPLIT_AT

Function names (enable_flux2_dual_gpu) and the helper module name (flux2_dual_gpu_diffsynth.py) keep the model-specific naming since they're the entry points users explicitly import for FLUX.2 training.

Coordinated with PR modelscope#1436, which adopts the same DIFFSYNTH_DUAL_GPU gate on the Wan side. Once both merge, one env var enables dual-GPU for either model family; runner.py recognizes the gate regardless of which family the training script targets.
Out of draft -- both review gates passed:

- The env var naming suggestion is addressed (WAN_DUAL_GPU -> DIFFSYNTH_DUAL_GPU, see the force-push above).
- The WanToDance extension suggestion is parked as a follow-up; this PR stays focused on the standard WanModel.blocks path.
Thank you for the update and for coordinating the environment variable naming across the FLUX.2 and Wan model families. The validation results on the 2× RTX 5090 setup look excellent, and it's great to see the dual-GPU model-parallel path successfully handling the memory requirements for the 14B variants. I have no further comments on this implementation. Good luck with the remaining cross-GPU-vendor validation.
Single-card fp8 baseline -- head-to-head with dual-GPU bf16

Ran the upstream single-card fp8 recipe for the head-to-head comparison:
Takeaway: at this resolution / frame count, the dual-GPU bf16 path is not measurably faster than the supported single-card fp8 recipe. The per-block cross-device bridging overhead on the dual-GPU side roughly cancels the slowdown the single-card path pays for routing fp8 weights through the slower fp8 GEMM path on Blackwell.

Where dual-GPU actually wins:
So the framing isn't "dual-GPU is faster" -- it's "dual-GPU lets you train Wan 2.2 A14B in full bf16 precision on a pair of 32 GB cards, where the single-card path requires fp8 weights everywhere." For users training characters / fine-detail subjects, that precision matters; for users who just want a working LoRA in any precision, the upstream single-card recipe is the better fit.
Adding context: the 24 GB market is the real motivation. The 32 GB / 32 GB benchmark above understates the value because the single-card fp8 path actually fits on a single 5090. On the bulk of the consumer-GPU population (24 GB cards: RTX 3090, RTX 4090, RTX A5000) the story is very different:
In other words: this isn't really about being faster than the single-card fp8 recipe -- it's about enabling Wan 2.2 14B training in clean bf16 on hardware that otherwise can't run it at all, or only with the heaviest offload settings.

We don't have a 2× 24 GB rig on the bench right now to post measured numbers, but the architecture (~7 GB DiT half on each card + bf16 T5 on cuda:0) leaves comfortable headroom on 24 GB. Happy to validate on 2× 3090/4090 if a maintainer has access and wants concrete numbers before merge.
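As a rough back-of-envelope for the "~7 GB DiT half" figure quoted above (assuming a ~14B-parameter DiT and fp8 weight-only quantization; illustrative arithmetic, not measurements):

```python
# Per-card weight budget for a midpoint split of a ~14B-parameter DiT.
# Activations, optimizer state, the T5 text encoder and the VAE come on top of this.
dit_params = 14e9

def half_weights_gb(bytes_per_param: float) -> float:
    return dit_params / 2 * bytes_per_param / 1e9

print(f"fp8 weight-only half per card: ~{half_weights_gb(1):.0f} GB")  # ~7 GB, as quoted above
print(f"bf16 half per card:            ~{half_weights_gb(2):.0f} GB")  # ~14 GB
```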
Summary
Adds an env-var-gated dual-GPU model-parallel path for Wan 2.x LoRA training in `examples/wanvideo/model_training/`. Splits `WanModel.blocks` at the midpoint across `cuda:0`/`cuda:1`, with per-block forward pre-hooks bridging the cross-device activation flow.

Off by default -- no behavior change unless `DIFFSYNTH_DUAL_GPU=true` is set.

Depends on #1435 (the patchify fix). The current main has a broken `WanModel.patchify` that returns the wrong shape and arity; Wan training fails immediately at the first forward call regardless of dual-GPU. Once #1435 lands, both single-GPU and dual-GPU Wan training paths work.

Depends on #1434 (the FLUX.2 dual-GPU PR) for the `diffsynth/diffusion/runner.py` env-gated skip of `model.to(accelerator.device)`. Without that change, runner.py moves the manually-split DiT back to a single device and silently undoes the distribute. Both PRs use the same `DIFFSYNTH_DUAL_GPU` env var so one signal controls both Wan and FLUX.2 paths.

Why
The Wan 2.2 14B variants (I2V-A14B, T2V-A14B, S2V-14B, ...) are ~28 GB in bf16. The weights fit on a single 32 GB consumer card after fp8 weight-only quant, but video training activations at 480 × 832 × 49 frames with gradient checkpointing routinely push the actual step memory over 32 GB even on a 14B model. The single-GPU answer is to drop resolution / frame count / precision; the dual-GPU answer is to split the transformer blocks across two cards and let each side hold half the activations.
This patch is the same shape as the validated FLUX.2 port in #1434 from this account, adapted for Wan's simpler block structure (one `DiTBlock` type vs. FLUX.2's double + single split). Cross-trainer write-up at https://github.com/genno-whittlery/flux2-dual-gpu-lora.

What changed
examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py (new, ~150 LOC)

The helper. Distributes:

- `cuda:0`: `patch_embedding`, `text_embedding`, `time_embedding`, `time_projection`, first half of `blocks`, `head`, plus optional `img_emb` / `ref_conv` / `control_adapter`
- `cuda:1`: second half of `blocks`

Registers `forward_pre_hook` with `with_kwargs=True` on every block in `blocks[split_at:]` -- `WanModel.forward` passes loop-level constants (`context`, `t_mod`, `freqs`) positionally to every iteration, so a boundary-only hook would leave subsequent cuda:1 blocks receiving the cuda:0 originals. Also one hook on `head` bridging activations back to cuda:0 for the final layers.

Edge case worth flagging: `WanModel.freqs` is a tuple of plain CPU tensors, not registered buffers, so `nn.Module.to(device)` doesn't move it. The helper handles this explicitly.

examples/wanvideo/model_training/train.py

- Forces CPU model load when `DIFFSYNTH_DUAL_GPU=true` (so the ~28 GB bf16 transformer doesn't pre-allocate on cuda:0 before split).
- Runs `torchao.quantize_` with `Float8WeightOnlyConfig` + a `filter_fn` that excludes `lora_A`/`lora_B` Linear submodules (otherwise their `requires_grad` is stripped and backward fails). Same pattern as #1434; a sketch of the filter follows this list.
- Calls `enable_wan_dual_gpu(model.pipe.dit)` after PEFT has injected LoRA so `block.to(device)` carries LoRA params with their base layers.
- Moves `pipe.text_encoder` and `pipe.vae` to `cuda:0`. The runner.py global `model.to()` is skipped (per #1434's gate), so non-DiT components otherwise stay on CPU and the first encode step crashes with a cross-device `index_select`. FLUX.2 doesn't need this because Mistral is pre-cached on CPU; Wan encodes T5 on-the-fly during the training step.
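A sketch of the LoRA-skip quantization filter referenced above, assuming torchao's config-based `quantize_` API with a `filter_fn(module, fqn)`; the exact filter in the PR may differ:

```python
import torch.nn as nn
from torchao.quantization import Float8WeightOnlyConfig, quantize_


def _skip_lora_linears(module: nn.Module, fqn: str) -> bool:
    """Quantize plain Linear layers but leave LoRA adapters (lora_A / lora_B) alone,
    so their requires_grad isn't stripped by the weight-only quantization."""
    if not isinstance(module, nn.Linear):
        return False
    return ".lora_A" not in fqn and ".lora_B" not in fqn


# quantize_ mutates the model in place; the description above runs it while the model
# is still on CPU, before the blocks are distributed across cuda:0 / cuda:1, e.g.:
# quantize_(model.pipe.dit, Float8WeightOnlyConfig(), filter_fn=_skip_lora_linears)
```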
Validation

End-to-end smoke train on 2× RTX 5090 (Wan2.2-I2V-A14B high_noise_model, full bf16 base + fp8 weight-only DiT, LoRA rank 32 on `q,k,v,o,ffn.0,ffn.2`, 480×832×49 frames, gradient checkpointing + offload, 5 steps):

Single-card fp8 baseline (`--fp8_models` on all four models, one 5090, no dual-GPU) for head-to-head s/it comparison is a planned follow-up -- will post as a comment.
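The "LoRA gradients land on both cuda:0 and cuda:1" claim from the description boils down to a check of this shape (a hedged fragment: it assumes `dit` is the split transformer right after one forward + backward pass):

```python
import torch

# Collect the devices of every LoRA parameter that actually received a gradient.
devices_with_lora_grads = {
    p.grad.device
    for n, p in dit.named_parameters()
    if ("lora_A" in n or "lora_B" in n) and p.grad is not None
}

# Autograd must have crossed the cuda:0 / cuda:1 boundary in both directions.
assert torch.device("cuda:0") in devices_with_lora_grads
assert torch.device("cuda:1") in devices_with_lora_grads
```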
Notes for reviewers

- Renamed `WAN_DUAL_GPU` -> `DIFFSYNTH_DUAL_GPU` here, and coordinated `FLUX2_DUAL_GPU` -> `DIFFSYNTH_DUAL_GPU` in #1434. One env var controls dual-GPU mode for either model family. The dead `os.environ["FLUX2_DUAL_GPU"]="true"` re-broadcast in train.py is gone -- the launcher-set env var is inherited by the same process and read by both `is_dual_gpu_enabled()` and the runner.py gate.
- Scope is `WanModel`'s standard `blocks` list; WanToDance variants have their own block list + different forward signatures that warrant a separate hooks pass to do correctly.

Test plan