Skip to content

examples/dreambooth: fix missing weighting chunk when using prior preservation in Flux and SD3 LoRA training#13743

Merged
sayakpaul merged 4 commits into
huggingface:mainfrom
Dev-X25874:fix/dreambooth-prior-preservation-weighting-chunk
May 18, 2026
Merged

examples/dreambooth: fix missing weighting chunk when using prior preservation in Flux and SD3 LoRA training#13743
sayakpaul merged 4 commits into
huggingface:mainfrom
Dev-X25874:fix/dreambooth-prior-preservation-weighting-chunk

Conversation

@Dev-X25874
Copy link
Copy Markdown
Contributor

What does this PR do?

When --with_prior_preservation is enabled, the training batch concatenates
instance and class (prior) samples, so every per-sample tensor —
model_pred, target, sigmas, and therefore weighting — has shape
(2 * train_batch_size, ...).

Inside the loss block, model_pred and target are correctly split via
torch.chunk(..., 2, dim=0), but weighting was never chunked. This means:

  • weighting (size 2B) is broadcast against model_pred_prior and
    target_prior (size B), producing a loss tensor of the wrong shape and
    applying incorrectly paired timestep weights to the prior loss term.
  • The instance loss term also gets weights from the full unsplit weighting
    instead of only the instance-sample half.

The correct pattern already exists in train_dreambooth_lora_flux2.py:

weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)

This PR applies the same fix to train_dreambooth_lora_flux.py and
train_dreambooth_lora_sd3.py, which were both missing it.

Fixes # (issue)

Before submitting

Who can review?

@sayakpaul

…target when using prior preservation (flux LoRA)
…target when using prior preservation (SD3 LoRA)
@github-actions github-actions Bot added examples size/S PR with diff < 50 LOC labels May 14, 2026
@Dev-X25874
Copy link
Copy Markdown
Contributor Author

Hi @sayakpaul, would you mind taking a look at this when you get a chance?

The bug is present in both train_dreambooth_lora_flux.py and train_dreambooth_lora_sd3.py — when --with_prior_preservation is enabled, weighting is never chunked alongside model_pred and target, causing incorrect timestep weights to be applied to the prior loss term. The fix already exists in train_dreambooth_lora_flux2.py (line 1832), so this PR simply backports it to the two older scripts. Happy to make any changes if needed!

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul merged commit 387a471 into huggingface:main May 18, 2026
16 of 30 checks passed
@Dev-X25874
Copy link
Copy Markdown
Contributor Author

Thank you @sayakpaul for the review and the quick merge! Really appreciate it.
Looking forward to contributing more! 🙏

@Dev-X25874 Dev-X25874 deleted the fix/dreambooth-prior-preservation-weighting-chunk branch May 18, 2026 10:52
Enderfga pushed a commit to Enderfga/diffusers that referenced this pull request May 19, 2026
…reservation in Flux and SD3 LoRA training (huggingface#13743)

* examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (flux LoRA)

* examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (SD3 LoRA)

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@Dev-X25874
Copy link
Copy Markdown
Contributor Author

Hi @sayakpaul, hope you're doing well! A quick etiquette question — for PRs that fall under your area, should contributors tag you right when opening the PR? And if not, how long should we wait before tagging you if there's been no activity? Also, do you usually go through all PRs on your own or is an explicit tag always helpful to make sure nothing gets missed? Just want to be respectful of your time. Thanks a lot! 🙏

@sayakpaul
Copy link
Copy Markdown
Member

You can tag me whenever you feel comfortable. But as a sign of respect, waiting till 4-5 days is a good idea I think.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

examples size/S PR with diff < 50 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants