Allow bucket reshuffling with DreamBooth caches #13712

Open
azolotenkov wants to merge 6 commits into huggingface:main from azolotenkov:feat-bucket-epoch-reshuffle-caching

Conversation

@azolotenkov
Contributor

What does this PR do?

Allow DreamBooth bucket batches to reshuffle each epoch while keeping cached latents and custom-caption prompt embeddings aligned.

After #13353, bucket batches with cached latents/custom captions were kept in stable step order because caches were indexed by dataloader step. This fixes the underlying limitation by indexing cached latents and prompt embeddings by dataset sample index instead. The training dataloader can then reshuffle bucket batches each epoch without reading the wrong cached tensors.

The cache precompute pass now uses a non-dropping cache dataloader, so every sample that can appear in a later reshuffled training epoch has a cache entry.

This also avoids mutating static prompt embeddings inside the training loop. Each step now derives repeated prompt/text embeddings from the original static tensors, which keeps prior-preservation runs with multiple steps stable.
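
Roughly, the sample-indexed precompute pass looks like this (illustrative sketch only, not the exact diff; device/dtype handling and offloading are omitted, and the cache dataloader construction is simplified):

import torch
from torch.utils.data import DataLoader

# Non-dropping dataloader so every sample that can appear in a later epoch gets a cache entry.
cache_dataloader = DataLoader(
    train_dataset,
    batch_sampler=BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False),
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
)

# Caches are keyed by dataset sample index, not by dataloader step.
instance_latents_cache = [None] * train_dataset.num_instance_images
class_latents_cache = [None] * train_dataset.num_instance_images

with torch.no_grad():
    for batch in cache_dataloader:
        sample_indices = batch["indices"]
        latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
        # collate_fn orders each batch as [instance..., class...].
        for i, idx in enumerate(sample_indices):
            instance_latents_cache[idx] = latents[i : i + 1]
            if args.with_prior_preservation:
                class_latents_cache[idx] = latents[len(sample_indices) + i : len(sample_indices) + i + 1]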

Tested:

Klein smoke tests with hf-internal-testing/tiny-flux2-klein:

  • static prompt, no prior, no cache
  • static prompt, no prior, --cache_latents
  • custom captions, no prior, no latent cache
  • custom captions, no prior, --cache_latents
  • static prompt + prior preservation, no cache
  • static prompt + prior preservation, --cache_latents
  • custom captions + prior preservation, no latent cache
  • custom captions + prior preservation, --cache_latents
  • custom captions + prior preservation + --cache_latents, crossing an epoch boundary with max_train_steps=7

Flux2 smoke tests with hf-internal-testing/tiny-flux2 using the standard tiny-model settings:

  • no prior preservation, no cache, train_batch_size=1, max_train_steps=2
  • prior preservation, train_batch_size=2, max_train_steps=2
  • prior preservation + --cache_latents, train_batch_size=2, max_train_steps=3

Before submitting

Who can review?

@sayakpaul

Copilot AI review requested due to automatic review settings May 10, 2026 19:42
@github-actions github-actions Bot added the examples and size/L (PR with diff > 200 LOC) labels May 10, 2026
Contributor

Copilot AI left a comment


Pull request overview

Enables epoch-wise reshuffling of DreamBooth bucketed batches in the Flux2 DreamBooth LoRA example scripts while keeping cached latents and (custom-caption) prompt/text embeddings correctly aligned by switching caches from step-indexing to dataset-sample indexing.

Changes:

  • Add per-sample index to dataset items and propagate it through collate_fn so caches can be keyed by sample index rather than dataloader step.
  • Rework latent/prompt-embedding caching to precompute via a non-dropping cache dataloader and store per-sample cached tensors.
  • Update BucketBatchSampler to reshuffle indices/batches on each __iter__() call (epoch reshuffle) while keeping __len__ stable (sketched below).
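
A minimal sketch of the reshuffling sampler (the buckets_to_indices() helper is a hypothetical stand-in for however the dataset groups indices by bucket; the actual implementation differs in details):

import random
from torch.utils.data import Sampler

class BucketBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

    def _bucketed_batches(self, shuffle):
        batches = []
        for indices_in_bucket in self.dataset.buckets_to_indices():  # hypothetical helper
            indices = list(indices_in_bucket)
            if shuffle:
                random.shuffle(indices)  # fresh sample order inside each bucket
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batches.append(batch)
        return batches

    def __iter__(self):
        batches = self._bucketed_batches(shuffle=True)
        random.shuffle(batches)  # also reshuffle batch order across buckets, every epoch
        return iter(batches)

    def __len__(self):
        # Stable across epochs: depends only on bucket sizes and drop_last.
        return len(self._bucketed_batches(shuffle=False))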

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Changed files:

  • examples/dreambooth/train_dreambooth_lora_flux2.py: Switch cache alignment to sample-indexing and re-enable epoch reshuffling for bucket batches.
  • examples/dreambooth/train_dreambooth_lora_flux2_klein.py: Same cache-by-index + epoch reshuffle changes for the Klein variant of the training script.


Comment thread examples/dreambooth/train_dreambooth_lora_flux2.py
Comment thread examples/dreambooth/train_dreambooth_lora_flux2_klein.py
@azolotenkov azolotenkov force-pushed the feat-bucket-epoch-reshuffle-caching branch from 6cf74c0 to b022936 Compare May 10, 2026 19:56
@sayakpaul
Member

@claude can you do a review here?

@github-actions
Contributor

github-actions Bot commented May 11, 2026

Claude finished @sayakpaul's task in 5m 3s


Review: Allow bucket reshuffling with DreamBooth caches

  • Read review guidelines (AGENTS.md)
  • Read the full changes in both files
  • Trace cache precompute and training-loop retrieval paths
  • Post review feedback

Overall: The approach is sound. Switching from step-indexed to sample-indexed caching correctly decouples cache lookups from batch ordering, enabling epoch reshuffling without misaligned cache reads. The non-dropping cache dataloader (drop_last=False) ensures every sample has a cache entry. Both files are consistent with one another (modulo Klein-specific differences: no remote_text_encoder path, Flux2KleinPipeline vs Flux2Pipeline).

No blocking issues found. A few observations and suggestions below.


Correctness trace

I traced the full flow and confirmed the design holds:

  1. BucketBatchSampler generates indices 0..len(pixel_values)-1, which equals num_instance_images (already includes repeats). Since cache lists are sized [None] * num_instance_images, all indices are in-bounds.

  2. Cache precompute (drop_last=False) visits every sample. Training dataloader (drop_last=True) may drop remainder batches, but those samples still have cache entries. Reshuffling changes which samples land in the remainder, but the cache has full coverage regardless.

  3. Prior-preservation alignment: collate_fn orders batches as [inst1..instB, class1..classB]. The cache retrieval reconstructs this same layout — instance latents first, then class latents. Prompt handling follows the same pattern (custom prompts concat'd with repeated class prompts). This is consistent.

  4. Static prompt path: repeat_interleave from the immutable static_prompt_embeds/static_text_ids avoids mutating shared tensors across steps — this is a nice improvement over the prior pattern (see the sketch after this list).
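
A sketch of that per-step expansion (names assumed from the scripts; illustrative only):

batch_size = len(sample_indices)
# Derive the per-step tensors from the immutable static embeddings instead of
# repeating them in place, so the originals are never mutated across steps.
prompt_embeds = static_prompt_embeds.repeat_interleave(batch_size, dim=0)
text_ids = static_text_ids.repeat_interleave(batch_size, dim=0)
if args.with_prior_preservation:
    prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states.repeat(batch_size, 1, 1)], dim=0)
    text_ids = torch.cat([text_ids, class_text_ids.repeat(batch_size, 1, 1)], dim=0)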


Suggestions (non-blocking)

  1. Consider an assertion after the cache loop (train_dreambooth_lora_flux2.py:1645, train_dreambooth_lora_flux2_klein.py:1586)

    After caching, there's no validation that every slot was filled. If a future change accidentally breaks the sampler, you'd get a silent None in the training loop (likely surfacing as a confusing torch.cat error). A quick sanity check would surface the real cause:

    if args.cache_latents:
        assert all(x is not None for x in instance_latents_cache), "Latent cache has unfilled entries"
  2. Unused batch["instance_prompts"] in training loop

    During training, batch["instance_prompts"] is never read — prompts come from prompt_embeds_cache (custom caption path) or static_prompt_embeds (static path). It's harmless but does mean the string data travels through the dataloader for no reason. Not worth changing now, just noting for awareness.

  3. Minor: collate_fn initializes both instance_prompts and prompts from the same source (train_dreambooth_lora_flux2.py:963-964, train_dreambooth_lora_flux2_klein.py:958-959)

    instance_prompts = [example["instance_prompt"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]

    These are separate list objects (so the prompts += class_prompts mutation is safe), but at first glance it reads like a duplication bug. A short inline comment explaining that instance_prompts stays as-is for cache keying while prompts gets extended for prior preservation would help future readers; a sketch of such a comment follows below.
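
    For example (illustrative wording only):

    # instance_prompts stays as-is for per-sample cache keying; prompts is the
    # training-batch list that gets extended with class prompts under prior preservation.
    instance_prompts = [example["instance_prompt"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]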


Summary

The core design (sample-indexed caches + drop_last=False precompute + reshuffling __iter__) is correct and well-tested across the configurations listed in the PR description. LGTM with the optional suggestions above.


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@azolotenkov
Contributor Author

@sayakpaul Thanks for the review.
I addressed Claude’s suggestions in the latest push by adding cache-fill sanity checks after precompute and clarifying why instance_prompts is kept separate from prompts in collate_fn.

Member

@sayakpaul sayakpaul left a comment


The diff looks very complicated to me, to the point where it's getting increasingly difficult for me to wrap my head around. Can you please explain the major changes inline?

batch = {
    "indices": indices,
    "pixel_values": pixel_values,
    "instance_prompts": instance_prompts,
Member


Why do we need the instance prompts?

Contributor Author


With train_batch_size=2, custom captions, and prior preservation, collate_fn receives two examples:

examples = [
    {
        "index": 0,
        "instance_prompt": "a sks dog on grass",
        "instance_images": inst_img_0,
        "class_prompt": "a dog",
        "class_images": class_img_0,
    },
    {
        "index": 1,
        "instance_prompt": "a sks dog near water",
        "instance_images": inst_img_1,
        "class_prompt": "a dog",
        "class_images": class_img_1,
    },
]

First, it stores sample indices and instance image tensors:

indices = [example["index"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]

Result:

indices = [0, 1]
pixel_values = [inst_img_0, inst_img_1]

Then it builds two prompt lists from the same source:

instance_prompts = [example["instance_prompt"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]

Result:

instance_prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

If prior preservation is enabled, collate_fn appends class images and class prompts only to the training batch lists:

if with_prior_preservation:
    pixel_values += [example["class_images"] for example in examples]
    prompts += [example["class_prompt"] for example in examples]

Result:

pixel_values = [
    inst_img_0,
    inst_img_1,
    class_img_0,
    class_img_1,
]

prompts = [
    "a sks dog on grass",
    "a sks dog near water",
    "a dog",
    "a dog",
]

instance_prompts does not change:

instance_prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

Batch:

batch = {
    "indices": [0, 1],
    "pixel_values": tensor([inst_img_0, inst_img_1, class_img_0, class_img_1]),
    "instance_prompts": [
        "a sks dog on grass",
        "a sks dog near water",
    ],
    "prompts": [
        "a sks dog on grass",
        "a sks dog near water",
        "a dog",
        "a dog",
    ],
}

Why is the instance-only prompt list needed?

During custom-caption precompute, we cache prompt embeddings per dataset sample.

See lines 1604-1606:

if train_dataset.custom_instance_prompts:
    prompt_embeds_cache = [None] * train_dataset.num_instance_images
    text_ids_cache = [None] * train_dataset.num_instance_images

Then, inside the cache loop, we need to encode only the instance captions, not the class prompts.

See lines 1632-1646:

if train_dataset.custom_instance_prompts:
    prompt_embeds, text_ids = compute_text_embeddings(
        batch["instance_prompts"], text_encoding_pipeline
    )
    for i, idx in enumerate(sample_indices):
        prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
        text_ids_cache[idx] = text_ids[i : i + 1]

Using the example:

batch["instance_prompts"] = ["a sks dog on grass", "a sks dog near water"]

So we cache:

prompt_embeds_cache[0] = embed("a sks dog on grass")
prompt_embeds_cache[1] = embed("a sks dog near water")

Custom captions are per instance image.

Class prompts are not cached the same way. The class prompt is shared. Usually it is one fixed string:

--class_prompt "a dog"

So every class image uses the same prompt.

The script already encodes this fixed class prompt separately before training:

class_prompt_hidden_states, class_text_ids = compute_text_embeddings(
    args.class_prompt, text_encoding_pipeline
)

Then the training loop repeats the class embedding for the current batch size.

See lines 1791-1795:

if args.with_prior_preservation:
    prompt_embeds = torch.cat(
        [prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0
    )
    text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0)

Using the example:

prompt_embeds = [
    cached_embed_for_sample_0,
    cached_embed_for_sample_1,
    embed("a dog"),
    embed("a dog"),
]

This matches the image/latent order.

See lines 1805-1810:

model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)
if args.with_prior_preservation:
    model_input = torch.cat(
        [model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)],
        dim=0,
    )

So the latent batch is:

model_input = [
    instance_latent_0,
    instance_latent_1,
    class_latent_0,
    class_latent_1,
]

And the prompt batch is:

prompt_embeds = [
    embed("a sks dog on grass"),
    embed("a sks dog near water"),
    embed("a dog"),
    embed("a dog"),
]

The ordering matches.

That said, the batch does not strictly need a separate instance-only prompt list.

I could remove instance_prompts from collate_fn because after collate_fn, the first len(batch["indices"]) prompts are always the instance prompts.

Then we have:

batch = {
    "indices": indices,
    "pixel_values": pixel_values,
    "prompts": prompts,
}

Result:

batch["prompts"] = [
    "a sks dog on grass",
    "a sks dog near water",
    "a dog",
    "a dog",
]

batch["indices"] = [0, 1]

So we can derive:

sample_indices = batch["indices"]
instance_prompts = batch["prompts"][: len(sample_indices)]

Result:

instance_prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

Then cache precompute can use the local variable:

prompt_embeds, text_ids = compute_text_embeddings(instance_prompts, text_encoding_pipeline)

This gives the same behavior with less batch state.
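
If we went that route, a slimmed-down collate_fn could look roughly like this (hypothetical sketch matching the structure above, not the current diff):

def collate_fn(examples, with_prior_preservation=False):
    indices = [example["index"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]

    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]

    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
    return {"indices": indices, "pixel_values": pixel_values, "prompts": prompts}

Callers that need only the instance captions would then slice:

instance_prompts = batch["prompts"][: len(batch["indices"])]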

dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
Member


Why is shuffle_batches_each_epoch going away? Because we're storing indices?

Contributor Author


Yes. shuffle_batches_each_epoch=False was needed because caches were step-indexed: cache[step] had to match the exact batch seen at that dataloader step during precompute.

This PR makes caches sample-indexed instead:

dataset index -> batch["indices"] -> cache[index] -> training retrieves cache[index]

So cache correctness no longer depends on stable dataloader step order, and epoch-wise reshuffling is safe.
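
A toy illustration: epoch 1 might yield batches [[0, 3], [1, 2]] and epoch 2 [[2, 0], [3, 1]]; either way, each training step reads the cache by the dataset indices carried in the batch:

sample_indices = batch["indices"]  # e.g. [2, 0] after reshuffling
model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)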

for i in range(0, len(indices_in_bucket), self.batch_size):
    batch = indices_in_bucket[i : i + self.batch_size]
shuffled_indices = indices_in_bucket.copy()
random.shuffle(shuffled_indices)
Member


We should be able to seed it too, for reproducibility?

Contributor Author


I thought it was already seeded indirectly because accelerate.utils.set_seed seeds Python random, and BucketBatchSampler uses random.shuffle, so reshuffling should be reproducible when --seed is set.

But you're right that the sampler sequence can be affected by any other code that consumes Python random state. I can make this explicit by passing args.seed into BucketBatchSampler and using a sampler-local random.Random(seed) instance.
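
Roughly (sketch of the proposed follow-up, only the seeding-relevant parts shown):

import random

class BucketBatchSampler:
    def __init__(self, dataset, batch_size, drop_last=False, seed=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Sampler-local RNG: reshuffling stays reproducible even if other code
        # consumes the global `random` state between epochs.
        self._rng = random.Random(seed)

    def _shuffle(self, indices):
        shuffled = list(indices)
        self._rng.shuffle(shuffled)
        return shuffled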

Comment on lines -1617 to +1623
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
Member


Why are we switching to .mode() here?

Contributor Author


latents contains the encoded latents for the whole batch. vae.encode(image).latent_dist is not a latent tensor; it is a latent distribution. To turn that distribution into a deterministic latent tensor we use .mode():

def mode(self):
    return self.mean

No random noise is added, so the cached latents are deterministic.

The script already uses .mode() for the non-cached path:

See line 1815:

model_input = vae.encode(pixel_values).latent_dist.mode()

So the cached path should also use .mode() to match that behavior.
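
For illustration, assuming an AutoencoderKL-style VAE from diffusers (randomly initialized here just to show the two calls):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()  # default config, random weights; enough to show the API
pixel_values = torch.randn(1, 3, 32, 32)

posterior = vae.encode(pixel_values).latent_dist  # a DiagonalGaussianDistribution
deterministic = posterior.mode()   # equals posterior.mean, identical on every call
stochastic = posterior.sample()    # mean + std * noise, differs per call

# Caching .sample() would freeze a single noise draw for the whole run; caching .mode()
# keeps the cached path equivalent to the non-cached path, which also calls .mode().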
