-
Notifications
You must be signed in to change notification settings - Fork 7k
Allow bucket reshuffling with DreamBooth caches #13712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b022936
f6e37a8
95153a9
3b43f3e
a9483d9
967a3fe
5b86bc9
a89412c
b08e214
c281c09
df5e907
882d892
17c4887
fd7fed7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -905,6 +905,7 @@ def __len__(self): | |
| def __getitem__(self, index): | ||
| example = {} | ||
| instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] | ||
| example["index"] = index | ||
| example["instance_images"] = instance_image | ||
| example["bucket_idx"] = bucket_idx | ||
| if self.custom_instance_prompts: | ||
|
|
@@ -957,7 +958,10 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip | |
|
|
||
|
|
||
| def collate_fn(examples, with_prior_preservation=False): | ||
| indices = [example["index"] for example in examples] | ||
| pixel_values = [example["instance_images"] for example in examples] | ||
| # Keep instance_prompts unchanged for prompt cache precompute; prompts may be extended with class prompts below. | ||
| instance_prompts = [example["instance_prompt"] for example in examples] | ||
| prompts = [example["instance_prompt"] for example in examples] | ||
|
|
||
| # Concat class and instance examples for prior preservation. | ||
|
|
@@ -969,18 +973,17 @@ def collate_fn(examples, with_prior_preservation=False): | |
| pixel_values = torch.stack(pixel_values) | ||
| pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() | ||
|
|
||
| batch = {"pixel_values": pixel_values, "prompts": prompts} | ||
| batch = { | ||
| "indices": indices, | ||
| "pixel_values": pixel_values, | ||
| "instance_prompts": instance_prompts, | ||
| "prompts": prompts, | ||
| } | ||
| return batch | ||
|
|
||
|
|
||
| class BucketBatchSampler(BatchSampler): | ||
| def __init__( | ||
| self, | ||
| dataset: DreamBoothDataset, | ||
| batch_size: int, | ||
| drop_last: bool = False, | ||
| shuffle_batches_each_epoch: bool = True, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Why is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yes. This PR makes caches sample-indexed instead:
So cache correctness no longer depends on stable dataloader step order, and epoch-wise reshuffling is safe. |
||
| ): | ||
| def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, seed: int = None): | ||
| if not isinstance(batch_size, int) or batch_size <= 0: | ||
| raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) | ||
| if not isinstance(drop_last, bool): | ||
|
|
@@ -989,37 +992,33 @@ def __init__( | |
| self.dataset = dataset | ||
| self.batch_size = batch_size | ||
| self.drop_last = drop_last | ||
| self.shuffle_batches_each_epoch = shuffle_batches_each_epoch | ||
| self.generator = random.Random(seed) if seed is not None else random | ||
|
|
||
| # Group indices by bucket | ||
| self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] | ||
| for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): | ||
| self.bucket_indices[bucket_idx].append(idx) | ||
|
|
||
| self.sampler_len = 0 | ||
| self.batches = [] | ||
| for indices_in_bucket in self.bucket_indices: | ||
| num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size) | ||
| self.sampler_len += num_batches | ||
| if remainder > 0 and not self.drop_last: | ||
| self.sampler_len += 1 | ||
|
|
||
| # Pre-generate batches for each bucket | ||
| def __iter__(self): | ||
| batches = [] | ||
| for indices_in_bucket in self.bucket_indices: | ||
| # Shuffle indices within the bucket | ||
| random.shuffle(indices_in_bucket) | ||
| # Create batches | ||
| for i in range(0, len(indices_in_bucket), self.batch_size): | ||
| batch = indices_in_bucket[i : i + self.batch_size] | ||
| shuffled_indices = indices_in_bucket.copy() | ||
| self.generator.shuffle(shuffled_indices) | ||
| for i in range(0, len(shuffled_indices), self.batch_size): | ||
| batch = shuffled_indices[i : i + self.batch_size] | ||
| if len(batch) < self.batch_size and self.drop_last: | ||
| continue # Skip partial batch if drop_last is True | ||
| self.batches.append(batch) | ||
| self.sampler_len += 1 # Count the number of batches | ||
|
|
||
| if not self.shuffle_batches_each_epoch: | ||
| # Shuffle the precomputed batches once to mix buckets while keeping | ||
| # the order stable across epochs for step-indexed caches. | ||
| random.shuffle(self.batches) | ||
| continue | ||
| batches.append(batch) | ||
|
|
||
| def __iter__(self): | ||
| if self.shuffle_batches_each_epoch: | ||
| random.shuffle(self.batches) | ||
| for batch in self.batches: | ||
| self.generator.shuffle(batches) | ||
| for batch in batches: | ||
| yield batch | ||
|
|
||
| def __len__(self): | ||
|
|
@@ -1480,13 +1479,8 @@ def load_model_hook(models, input_dir): | |
| center_crop=args.center_crop, | ||
| buckets=buckets, | ||
| ) | ||
| has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts | ||
| batch_sampler = BucketBatchSampler( | ||
| train_dataset, | ||
| batch_size=args.train_batch_size, | ||
| drop_last=True, | ||
| shuffle_batches_each_epoch=not has_step_indexed_caches, | ||
| ) | ||
| precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts | ||
| batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True, seed=args.seed) | ||
| train_dataloader = torch.utils.data.DataLoader( | ||
| train_dataset, | ||
| batch_sampler=batch_sampler, | ||
|
|
@@ -1599,32 +1593,72 @@ def _encode_single(prompt: str): | |
| if args.with_prior_preservation: | ||
| prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) | ||
| text_ids = torch.cat([text_ids, class_text_ids], dim=0) | ||
| static_prompt_embeds = prompt_embeds | ||
| static_text_ids = text_ids | ||
|
|
||
| # if cache_latents is set to True, we encode images to latents and store them. | ||
| # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided | ||
| # we encode them in advance as well. | ||
| if args.cache_latents: | ||
| instance_latents_cache = [None] * train_dataset.num_instance_images | ||
| class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None | ||
| if train_dataset.custom_instance_prompts: | ||
| prompt_embeds_cache = [None] * train_dataset.num_instance_images | ||
| text_ids_cache = [None] * train_dataset.num_instance_images | ||
| if precompute_latents: | ||
| prompt_embeds_cache = [] | ||
| text_ids_cache = [] | ||
| latents_cache = [] | ||
| for batch in tqdm(train_dataloader, desc="Caching latents"): | ||
| cache_batch_sampler = BucketBatchSampler( | ||
| train_dataset, batch_size=args.train_batch_size, drop_last=False, seed=args.seed | ||
| ) | ||
| cache_dataloader = torch.utils.data.DataLoader( | ||
| train_dataset, | ||
| batch_sampler=cache_batch_sampler, | ||
| collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), | ||
| num_workers=args.dataloader_num_workers, | ||
| ) | ||
| for batch in tqdm(cache_dataloader, desc="Caching latents"): | ||
| with torch.no_grad(): | ||
| sample_indices = batch["indices"] | ||
| if args.cache_latents: | ||
| with offload_models(vae, device=accelerator.device, offload=args.offload): | ||
| batch["pixel_values"] = batch["pixel_values"].to( | ||
| accelerator.device, non_blocking=True, dtype=vae.dtype | ||
| ) | ||
| latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) | ||
| latents = vae.encode(batch["pixel_values"]).latent_dist.mode() | ||
|
azolotenkov marked this conversation as resolved.
|
||
| if args.with_prior_preservation: | ||
| instance_latents, class_latents = torch.chunk(latents, 2, dim=0) | ||
| else: | ||
| instance_latents = latents | ||
| for i, idx in enumerate(sample_indices): | ||
| instance_latents_cache[idx] = instance_latents[i : i + 1] | ||
| if args.with_prior_preservation: | ||
| class_latents_cache[idx] = class_latents[i : i + 1] | ||
| if train_dataset.custom_instance_prompts: | ||
| if args.remote_text_encoder: | ||
| prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) | ||
| prompt_embeds, text_ids = compute_remote_text_embeddings(batch["instance_prompts"]) | ||
| elif args.fsdp_text_encoder: | ||
| prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) | ||
| prompt_embeds, text_ids = compute_text_embeddings( | ||
| batch["instance_prompts"], text_encoding_pipeline | ||
| ) | ||
| else: | ||
| with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | ||
| prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) | ||
| prompt_embeds_cache.append(prompt_embeds) | ||
| text_ids_cache.append(text_ids) | ||
| prompt_embeds, text_ids = compute_text_embeddings( | ||
| batch["instance_prompts"], text_encoding_pipeline | ||
| ) | ||
| for i, idx in enumerate(sample_indices): | ||
| prompt_embeds_cache[idx] = prompt_embeds[i : i + 1] | ||
| text_ids_cache[idx] = text_ids[i : i + 1] | ||
|
|
||
| if args.cache_latents: | ||
| assert all(latents is not None for latents in instance_latents_cache), "Latent cache has unfilled entries." | ||
| if args.with_prior_preservation: | ||
| assert all(latents is not None for latents in class_latents_cache), ( | ||
| "Class latent cache has unfilled entries." | ||
| ) | ||
| if train_dataset.custom_instance_prompts: | ||
| assert all(embeds is not None for embeds in prompt_embeds_cache), ( | ||
| "Prompt embedding cache has unfilled entries." | ||
| ) | ||
| assert all(ids is not None for ids in text_ids_cache), "Text ID cache has unfilled entries." | ||
|
|
||
| # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 | ||
| if args.cache_latents: | ||
|
|
@@ -1748,25 +1782,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | |
| for epoch in range(first_epoch, args.num_train_epochs): | ||
| transformer.train() | ||
|
|
||
| for step, batch in enumerate(train_dataloader): | ||
| for batch in train_dataloader: | ||
| models_to_accumulate = [transformer] | ||
| sample_indices = batch["indices"] | ||
| prompts = batch["prompts"] | ||
|
azolotenkov marked this conversation as resolved.
|
||
|
|
||
| with accelerator.accumulate(models_to_accumulate): | ||
| if train_dataset.custom_instance_prompts: | ||
| prompt_embeds = prompt_embeds_cache[step] | ||
| text_ids = text_ids_cache[step] | ||
| prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in sample_indices], dim=0) | ||
| text_ids = torch.cat([text_ids_cache[idx] for idx in sample_indices], dim=0) | ||
| if args.with_prior_preservation: | ||
| prompt_embeds = torch.cat( | ||
| [prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0 | ||
| ) | ||
| text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0) | ||
| else: | ||
| # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries, | ||
| # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along | ||
| # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...]. | ||
| num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) | ||
| prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) | ||
| text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0) | ||
| prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) | ||
| text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0) | ||
|
|
||
| # Convert images to latent space | ||
| if args.cache_latents: | ||
| model_input = latents_cache[step].mode() | ||
| model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0) | ||
| if args.with_prior_preservation: | ||
| model_input = torch.cat( | ||
| [model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)], | ||
| dim=0, | ||
| ) | ||
| else: | ||
| with offload_models(vae, device=accelerator.device, offload=args.offload): | ||
| pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need the instance prompts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With `train_batch_size=2`, custom captions, and prior preservation, `collate_fn` receives two examples:
First, it stores sample indices and instance image tensors:
Result:
Then it builds two prompt lists from the same source:
Result:
If prior preservation is enabled, `collate_fn` appends class images and class prompts only to the training batch lists:
Result:
`instance_prompts` does not change:
Batch:
Why is the instance-only prompt list needed?
During custom-caption precompute, we cache prompt embeddings per dataset sample.
See lines 1604-1606:
Then, inside the cache loop, we need to encode only the instance captions, not the class prompts.
See lines 1632-1646:
Using the example:
So we cache:
Custom captions are per instance image.
Class prompts are not cached the same way. The class prompt is shared. Usually it is one fixed string:
`--class_prompt "a dog"`
So every class image uses the same prompt.
The script already encodes this fixed class prompt separately before training:
Then the training loop repeats the class embedding for the current batch size.
See lines 1791-1795:
Using the example:
This matches the image/latent order.
See lines 1805-1810:
So the latent batch is:
And the prompt batch is:
The ordering matches.
Also, we need an instance-only prompt list.
I could remove `instance_prompts` from `collate_fn` because after `collate_fn`, the first `len(batch["indices"])` prompts are always the instance prompts.
Then we have:
Result:
So we can derive:
Result:
Then cache precompute can use the local variable:
This gives the same behavior with less batch state.