examples/dreambooth: fix missing weighting chunk when using prior preservation in Flux and SD3 LoRA training #13743
…target when using prior preservation (flux LoRA)
…target when using prior preservation (SD3 LoRA)
Hi @sayakpaul, would you mind taking a look at this when you get a chance? The bug is present in both
Thank you @sayakpaul for the review and the quick merge! Really appreciate it.
…reservation in Flux and SD3 LoRA training (huggingface#13743)
* examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (flux LoRA)
* examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (SD3 LoRA)
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Hi @sayakpaul, hope you're doing well! A quick etiquette question: for PRs that fall under your area, should contributors tag you right when opening the PR? And if not, how long should we wait before tagging you if there's been no activity? Also, do you usually go through all PRs on your own or is an explicit tag always helpful to make sure nothing gets missed? Just want to be respectful of your time. Thanks a lot! 🙏
You can tag me whenever you feel comfortable. But as a sign of respect, waiting till 4-5 days is a good idea I think. |
What does this PR do?
When `--with_prior_preservation` is enabled, the training batch concatenates instance and class (prior) samples, so every per-sample tensor (`model_pred`, `target`, `sigmas`, and therefore `weighting`) has shape `(2 * train_batch_size, ...)`.

Inside the loss block, `model_pred` and `target` are correctly split via `torch.chunk(..., 2, dim=0)`, but `weighting` was never chunked. This means:

- `weighting` (size `2B`) is broadcast against `model_pred_prior` and `target_prior` (size `B`), producing a loss tensor of the wrong shape and applying incorrectly paired timestep weights to the prior loss term.
- … `weighting` instead of only the instance-sample half.

The correct pattern already exists in `train_dreambooth_lora_flux2.py`. This PR applies the same fix to `train_dreambooth_lora_flux.py` and `train_dreambooth_lora_sd3.py`, which were both missing it.

Fixes # (issue)
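
To make the shape argument in the description concrete, here is a minimal, self-contained sketch of the chunking pattern. It is not copied from any of the training scripts: the function names `prior_preservation_loss` and `unchunked_product_shape`, the `prior_loss_weight` default, and the `(B, 4, 8, 8)` tensor sizes are assumptions chosen for illustration. The only library calls used are `torch.chunk` and ordinary tensor arithmetic.

```python
import torch


def prior_preservation_loss(model_pred, target, weighting, prior_loss_weight=1.0):
    """Weighted MSE with prior preservation (illustrative sketch, hypothetical helper).

    All inputs have a leading dimension of 2 * B: instance samples first,
    class (prior) samples second.
    """
    # Split every per-sample tensor the same way, so each weight stays
    # paired with the sample it was computed for.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)

    def weighted_mse(pred, tgt, w):
        per_elem = w.float() * (pred.float() - tgt.float()) ** 2
        # Mean over each sample's elements, then over the batch.
        return per_elem.reshape(tgt.shape[0], -1).mean(dim=1).mean()

    instance_loss = weighted_mse(model_pred, target, weighting)
    prior_loss = weighted_mse(model_pred_prior, target_prior, weighting_prior)
    return instance_loss + prior_loss_weight * prior_loss


def unchunked_product_shape(batch_size):
    """Shape of `weighting * squared_error` when weighting is NOT chunked.

    Returns None if PyTorch refuses to broadcast the two tensors.
    """
    weighting = torch.ones(2 * batch_size, 1, 1, 1)  # shape (2B, 1, 1, 1)
    sq_err_prior = torch.ones(batch_size, 4, 8, 8)   # shape (B, C, H, W)
    try:
        return tuple((weighting * sq_err_prior).shape)
    except RuntimeError:
        return None


print(unchunked_product_shape(1))  # (2, 4, 8, 8): silently broadcast to 2B
print(unchunked_product_shape(2))  # None: 4 vs 2 cannot be broadcast
```

Under PyTorch's broadcasting rules, a leading dimension of `2B` is only compatible with `B` when `B` is 1, which is the case where the unchunked version runs without error and yields a tensor with the wrong leading size. For larger batch sizes the unchunked multiplication raises a `RuntimeError` instead. Chunking `weighting` alongside `model_pred` and `target` avoids both outcomes.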
Who can review?
@sayakpaul