Allow bucket reshuffling with DreamBooth caches #13712
Conversation
Pull request overview
Enables epoch-wise reshuffling of DreamBooth bucketed batches in the Flux2 DreamBooth LoRA example scripts while keeping cached latents and (custom-caption) prompt/text embeddings correctly aligned by switching caches from step-indexing to dataset-sample indexing.
Changes:
- Add a per-sample index to dataset items and propagate it through collate_fn so caches can be keyed by sample index rather than dataloader step.
- Rework latent/prompt-embedding caching to precompute via a non-dropping cache dataloader and store per-sample cached tensors (see the sketch below).
- Update BucketBatchSampler to reshuffle indices/batches on each __iter__() call (epoch reshuffle) while keeping __len__ stable.
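For orientation, a minimal sketch of the precompute pass described in the second bullet. This is an illustration under assumed names (train_dataset, collate_fn, args, vae, and the BucketBatchSampler constructor shape are stand-ins, not the exact diff):

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical sketch: walk every sample once with a non-dropping cache
# dataloader and fill a cache keyed by dataset sample index, so later training
# epochs can read cache[idx] regardless of batch order.
cache_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
cache_dataloader = DataLoader(train_dataset, batch_sampler=cache_sampler, collate_fn=collate_fn)

instance_latents_cache = [None] * train_dataset.num_instance_images
with torch.no_grad():
    for batch in cache_dataloader:
        latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
        for i, idx in enumerate(batch["indices"]):
            instance_latents_cache[idx] = latents[i : i + 1].cpu()
```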
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| examples/dreambooth/train_dreambooth_lora_flux2.py | Switch cache alignment to sample-indexing and re-enable epoch reshuffling for bucket batches. |
| examples/dreambooth/train_dreambooth_lora_flux2_klein.py | Same cache-by-index + epoch reshuffle changes for the Klein variant of the training script. |
@claude can you do a review here?
Claude finished @sayakpaul's task in 5m 3s — Review: Allow bucket reshuffling with DreamBooth caches
Overall: The approach is sound. Switching from step-indexed to sample-indexed caching correctly decouples cache lookups from batch ordering, enabling epoch reshuffling without misaligned cache reads. The non-dropping cache dataloader ensures that every sample which can appear during training has a cache entry. No blocking issues found. A few observations and suggestions below.
Correctness trace
I traced the full flow and confirmed the design holds.
Suggestions (non-blocking)
Summary
The core design (sample-indexed caches + per-epoch reshuffling) is sound.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@sayakpaul Thanks for the review.
sayakpaul left a comment
The diff looks very complicated to me, to the point where it's getting increasingly difficult for me to wrap my head around it. Can you please explain the major changes inline?
batch = {
"indices": indices,
"pixel_values": pixel_values,
"instance_prompts": instance_prompts,
Why do we need the instance prompts?
With train_batch_size=2, custom captions, and prior preservation, collate_fn receives two examples:
examples = [
{
"index": 0,
"instance_prompt": "a sks dog on grass",
"instance_images": inst_img_0,
"class_prompt": "a dog",
"class_images": class_img_0,
},
{
"index": 1,
"instance_prompt": "a sks dog near water",
"instance_images": inst_img_1,
"class_prompt": "a dog",
"class_images": class_img_1,
},
]
First, it stores sample indices and instance image tensors:
indices = [example["index"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
Result:
indices = [0, 1]
pixel_values = [inst_img_0, inst_img_1]
Then it builds two prompt lists from the same source:
instance_prompts = [example["instance_prompt"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
Result:
instance_prompts = [
"a sks dog on grass",
"a sks dog near water",
]
prompts = [
"a sks dog on grass",
"a sks dog near water",
]
If prior preservation is enabled, collate_fn appends class images and class prompts only to the training batch lists:
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
Result:
pixel_values = [
inst_img_0,
inst_img_1,
class_img_0,
class_img_1,
]
prompts = [
"a sks dog on grass",
"a sks dog near water",
"a dog",
"a dog",
]
instance_prompts does not change:
instance_prompts = [
"a sks dog on grass",
"a sks dog near water",
]
Batch:
batch = {
"indices": [0, 1],
"pixel_values": tensor([inst_img_0, inst_img_1, class_img_0, class_img_1]),
"instance_prompts": [
"a sks dog on grass",
"a sks dog near water",
],
"prompts": [
"a sks dog on grass",
"a sks dog near water",
"a dog",
"a dog",
],
}
Why is the instance-only prompt list needed?
During custom-caption precompute, we cache prompt embeddings per dataset sample.
See lines 1604-1606:
if train_dataset.custom_instance_prompts:
prompt_embeds_cache = [None] * train_dataset.num_instance_images
text_ids_cache = [None] * train_dataset.num_instance_images
Then, inside the cache loop, we need to encode only the instance captions, not the class prompts.
See lines 1632-1646:
if train_dataset.custom_instance_prompts:
prompt_embeds, text_ids = compute_text_embeddings(
batch["instance_prompts"], text_encoding_pipeline
)
for i, idx in enumerate(sample_indices):
prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
text_ids_cache[idx] = text_ids[i : i + 1]
Using the example:
batch["instance_prompts"] = ["a sks dog on grass", "a sks dog near water"]
So we cache:
prompt_embeds_cache[0] = embed("a sks dog on grass")
prompt_embeds_cache[1] = embed("a sks dog near water")
Custom captions are per instance image.
Class prompts are not cached the same way. The class prompt is shared. Usually it is one fixed string:
--class_prompt "a dog"
So every class image uses the same prompt.
The script already encodes this fixed class prompt separately before training:
class_prompt_hidden_states, class_text_ids = compute_text_embeddings(
args.class_prompt, text_encoding_pipeline
)
Then the training loop repeats the class embedding for the current batch size.
See lines 1791-1795:
if args.with_prior_preservation:
prompt_embeds = torch.cat(
[prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0
)
text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0)
Using the example:
prompt_embeds = [
cached_embed_for_sample_0,
cached_embed_for_sample_1,
embed("a dog"),
embed("a dog"),
]
This matches the image/latent order.
See lines 1805-1810:
model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)
if args.with_prior_preservation:
model_input = torch.cat(
[model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)],
dim=0,
)
So the latent batch is:
model_input = [
instance_latent_0,
instance_latent_1,
class_latent_0,
class_latent_1,
]
And the prompt batch is:
prompt_embeds = [
embed("a sks dog on grass"),
embed("a sks dog near water"),
embed("a dog"),
embed("a dog"),
]
The ordering matches.
That said, a separate instance-only prompt list in the batch is not strictly required. I could remove instance_prompts from collate_fn, because after collate_fn the first len(batch["indices"]) prompts are always the instance prompts.
Then we have:
batch = {
"indices": indices,
"pixel_values": pixel_values,
"prompts": prompts,
}
Result:
batch["prompts"] = [
"a sks dog on grass",
"a sks dog near water",
"a dog",
"a dog",
]
batch["indices"] = [0, 1]So we can derive:
sample_indices = batch["indices"]
instance_prompts = batch["prompts"][: len(sample_indices)]Result:
instance_prompts = [
"a sks dog on grass",
"a sks dog near water",
]
Then cache precompute can use the local variable:
prompt_embeds, text_ids = compute_text_embeddings(instance_prompts, text_encoding_pipeline)
This gives the same behavior with less batch state.
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
Why is shuffle_batches_each_epoch going away? Because we're storing indices?
Yes. shuffle_batches_each_epoch=False was needed because caches were step-indexed: cache[step] had to match the exact batch seen at that dataloader step during precompute.
This PR makes caches sample-indexed instead:
dataset index -> batch["indices"] -> cache[index] -> training retrieves cache[index]
So cache correctness no longer depends on stable dataloader step order, and epoch-wise reshuffling is safe.
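To make the lookup chain concrete, a hedged sketch of the training-side retrieval (mirroring the cached-path lines quoted elsewhere in this thread; variable names are illustrative):

```python
import torch

# Each training batch carries the dataset indices of its samples, so cached
# tensors are fetched by index and any batch order works across epochs.
sample_indices = batch["indices"]  # e.g. [7, 3] after this epoch's reshuffle
model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)
if train_dataset.custom_instance_prompts:
    prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in sample_indices], dim=0)
    text_ids = torch.cat([text_ids_cache[idx] for idx in sample_indices], dim=0)
```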
for i in range(0, len(indices_in_bucket), self.batch_size):
batch = indices_in_bucket[i : i + self.batch_size]
shuffled_indices = indices_in_bucket.copy()
random.shuffle(shuffled_indices)
We should be able to seed it too, for reproducibility?
I thought it was already seeded indirectly because accelerate.utils.set_seed seeds Python random, and BucketBatchSampler uses random.shuffle, so reshuffling should be reproducible when --seed is set.
But you're right that the sampler sequence can be affected by any other code that consumes the global Python random state. I can make this explicit by passing args.seed into BucketBatchSampler and using a sampler-local random.Random(seed) instance.
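A minimal sketch of what that could look like, assuming the sampler keeps its current bucketing logic (the seed argument, the buckets attribute, and the constructor shape here are assumptions, not the merged code):

```python
import random


class BucketBatchSampler:
    """Sketch only: reshuffle within-bucket order and batch order on every
    __iter__() call using a sampler-local RNG, so the sequence is reproducible
    even if other code consumes the global `random` state."""

    def __init__(self, buckets, batch_size, drop_last=False, seed=None):
        self.buckets = buckets          # list of index lists, one per resolution bucket
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._rng = random.Random(seed)

    def __len__(self):
        # Stable across epochs: reshuffling changes order, not the batch count.
        total = 0
        for indices_in_bucket in self.buckets:
            n = len(indices_in_bucket)
            total += n // self.batch_size if self.drop_last else -(-n // self.batch_size)
        return total

    def __iter__(self):
        batches = []
        for indices_in_bucket in self.buckets:
            shuffled_indices = list(indices_in_bucket)
            self._rng.shuffle(shuffled_indices)
            for i in range(0, len(shuffled_indices), self.batch_size):
                batch = shuffled_indices[i : i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batches.append(batch)
        self._rng.shuffle(batches)  # also reshuffle which batch comes when
        yield from batches
```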
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
Why are we switching to .mode() here?
latents contains the encoded latents for the whole batch. vae.encode(image).latent_dist is not a single latent tensor; it is a probability distribution. To turn that distribution into a latent tensor we use .mode():
def mode(self):
    return self.mean
No random noise is involved.
The script already uses .mode() for the non-cached path:
See line 1815:
model_input = vae.encode(pixel_values).latent_dist.mode()
So the cached path should also use .mode() to match that behavior.
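For illustration, assuming a diffusers-style VAE whose encode() output exposes a latent_dist with both sample() and mode():

```python
posterior = vae.encode(batch["pixel_values"]).latent_dist  # a distribution, not a tensor

deterministic_latents = posterior.mode()  # just the mean; identical on every call
stochastic_latents = posterior.sample()   # mean + std * noise; differs between calls

# Caching .mode() keeps the cached path deterministic and consistent with the
# non-cached path, which also calls .mode().
```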
What does this PR do?
Allow DreamBooth bucket batches to reshuffle each epoch while keeping cached latents and custom-caption prompt embeddings aligned.
After #13353, bucket batches with cached latents/custom captions were kept in stable step order because caches were indexed by dataloader step. This fixes the underlying limitation by indexing cached latents and prompt embeddings by dataset sample index instead. The training dataloader can then reshuffle bucket batches each epoch without reading the wrong cached tensors.
The cache precompute pass now uses a non-dropping cache dataloader, so every sample that can appear in a later reshuffled training epoch has a cache entry.
This also avoids mutating static prompt embeddings inside the training loop. Each step now derives repeated prompt/text embeddings from the original static tensors, which keeps prior-preservation runs with multiple steps stable.
Tested:
Klein smoke tests with hf-internal-testing/tiny-flux2-klein:
- several --cache_latents configurations, including one crossing an epoch boundary with max_train_steps=7
Flux2 smoke tests with hf-internal-testing/tiny-flux2 using the standard tiny-model settings:
- train_batch_size=1, max_train_steps=2
- train_batch_size=2, max_train_steps=2
- --cache_latents, train_batch_size=2, max_train_steps=3
Who can review?
@sayakpaul