diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 28722ec25e7a..886e251937e6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -905,6 +905,7 @@ def __len__(self): def __getitem__(self, index): example = {} instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["index"] = index example["instance_images"] = instance_image example["bucket_idx"] = bucket_idx if self.custom_instance_prompts: @@ -957,7 +958,10 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip def collate_fn(examples, with_prior_preservation=False): + indices = [example["index"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + # Keep instance_prompts unchanged for prompt cache precompute; prompts may be extended with class prompts below. + instance_prompts = [example["instance_prompt"] for example in examples] prompts = [example["instance_prompt"] for example in examples] # Concat class and instance examples for prior preservation. @@ -969,18 +973,17 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - batch = {"pixel_values": pixel_values, "prompts": prompts} + batch = { + "indices": indices, + "pixel_values": pixel_values, + "instance_prompts": instance_prompts, + "prompts": prompts, + } return batch class BucketBatchSampler(BatchSampler): - def __init__( - self, - dataset: DreamBoothDataset, - batch_size: int, - drop_last: bool = False, - shuffle_batches_each_epoch: bool = True, - ): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, seed: int = None): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -989,7 +992,7 @@ def __init__( self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last - self.shuffle_batches_each_epoch = shuffle_batches_each_epoch + self.generator = random.Random(seed) if seed is not None else random # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -997,29 +1000,25 @@ def __init__( self.bucket_indices[bucket_idx].append(idx) self.sampler_len = 0 - self.batches = [] + for indices_in_bucket in self.bucket_indices: + num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size) + self.sampler_len += num_batches + if remainder > 0 and not self.drop_last: + self.sampler_len += 1 - # Pre-generate batches for each bucket + def __iter__(self): + batches = [] for indices_in_bucket in self.bucket_indices: - # Shuffle indices within the bucket - random.shuffle(indices_in_bucket) - # Create batches - for i in range(0, len(indices_in_bucket), self.batch_size): - batch = indices_in_bucket[i : i + self.batch_size] + shuffled_indices = indices_in_bucket.copy() + self.generator.shuffle(shuffled_indices) + for i in range(0, len(shuffled_indices), self.batch_size): + batch = shuffled_indices[i : i + self.batch_size] if len(batch) < self.batch_size and self.drop_last: - continue # Skip partial batch if drop_last is True - self.batches.append(batch) - self.sampler_len += 1 # Count the number of batches - - if not self.shuffle_batches_each_epoch: - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + continue + batches.append(batch) - def __iter__(self): - if self.shuffle_batches_each_epoch: - random.shuffle(self.batches) - for batch in self.batches: + self.generator.shuffle(batches) + for batch in batches: yield batch def __len__(self): @@ -1480,13 +1479,8 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - batch_sampler = BucketBatchSampler( - train_dataset, - batch_size=args.train_batch_size, - drop_last=True, - shuffle_batches_each_epoch=not has_step_indexed_caches, - ) + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True, seed=args.seed) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1599,32 +1593,72 @@ def _encode_single(prompt: str): if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) text_ids = torch.cat([text_ids, class_text_ids], dim=0) + static_prompt_embeds = prompt_embeds + static_text_ids = text_ids # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. + if args.cache_latents: + instance_latents_cache = [None] * train_dataset.num_instance_images + class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None + if train_dataset.custom_instance_prompts: + prompt_embeds_cache = [None] * train_dataset.num_instance_images + text_ids_cache = [None] * train_dataset.num_instance_images if precompute_latents: - prompt_embeds_cache = [] - text_ids_cache = [] - latents_cache = [] - for batch in tqdm(train_dataloader, desc="Caching latents"): + cache_batch_sampler = BucketBatchSampler( + train_dataset, batch_size=args.train_batch_size, drop_last=False, seed=args.seed + ) + cache_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=cache_batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + for batch in tqdm(cache_dataloader, desc="Caching latents"): with torch.no_grad(): + sample_indices = batch["indices"] if args.cache_latents: with offload_models(vae, device=accelerator.device, offload=args.offload): batch["pixel_values"] = batch["pixel_values"].to( accelerator.device, non_blocking=True, dtype=vae.dtype ) - latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + latents = vae.encode(batch["pixel_values"]).latent_dist.mode() + if args.with_prior_preservation: + instance_latents, class_latents = torch.chunk(latents, 2, dim=0) + else: + instance_latents = latents + for i, idx in enumerate(sample_indices): + instance_latents_cache[idx] = instance_latents[i : i + 1] + if args.with_prior_preservation: + class_latents_cache[idx] = class_latents[i : i + 1] if train_dataset.custom_instance_prompts: if args.remote_text_encoder: - prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + prompt_embeds, text_ids = compute_remote_text_embeddings(batch["instance_prompts"]) elif args.fsdp_text_encoder: - prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds, text_ids = compute_text_embeddings( + batch["instance_prompts"], text_encoding_pipeline + ) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) - prompt_embeds_cache.append(prompt_embeds) - text_ids_cache.append(text_ids) + prompt_embeds, text_ids = compute_text_embeddings( + batch["instance_prompts"], text_encoding_pipeline + ) + for i, idx in enumerate(sample_indices): + prompt_embeds_cache[idx] = prompt_embeds[i : i + 1] + text_ids_cache[idx] = text_ids[i : i + 1] + + if args.cache_latents: + assert all(latents is not None for latents in instance_latents_cache), "Latent cache has unfilled entries." + if args.with_prior_preservation: + assert all(latents is not None for latents in class_latents_cache), ( + "Class latent cache has unfilled entries." + ) + if train_dataset.custom_instance_prompts: + assert all(embeds is not None for embeds in prompt_embeds_cache), ( + "Prompt embedding cache has unfilled entries." + ) + assert all(ids is not None for ids in text_ids_cache), "Text ID cache has unfilled entries." # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 if args.cache_latents: @@ -1748,25 +1782,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() - for step, batch in enumerate(train_dataloader): + for batch in train_dataloader: models_to_accumulate = [transformer] + sample_indices = batch["indices"] prompts = batch["prompts"] with accelerator.accumulate(models_to_accumulate): if train_dataset.custom_instance_prompts: - prompt_embeds = prompt_embeds_cache[step] - text_ids = text_ids_cache[step] + prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in sample_indices], dim=0) + text_ids = torch.cat([text_ids_cache[idx] for idx in sample_indices], dim=0) + if args.with_prior_preservation: + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0 + ) + text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0) else: # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries, # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...]. num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) - prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) - text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0) + prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) + text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0) # Convert images to latent space if args.cache_latents: - model_input = latents_cache[step].mode() + model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0) + if args.with_prior_preservation: + model_input = torch.cat( + [model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)], + dim=0, + ) else: with offload_models(vae, device=accelerator.device, offload=args.offload): pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 21cbc8a2c47b..7eb627e4bd1d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -900,6 +900,7 @@ def __len__(self): def __getitem__(self, index): example = {} instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["index"] = index example["instance_images"] = instance_image example["bucket_idx"] = bucket_idx if self.custom_instance_prompts: @@ -952,7 +953,10 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip def collate_fn(examples, with_prior_preservation=False): + indices = [example["index"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + # Keep instance_prompts unchanged for prompt cache precompute; prompts may be extended with class prompts below. + instance_prompts = [example["instance_prompt"] for example in examples] prompts = [example["instance_prompt"] for example in examples] # Concat class and instance examples for prior preservation. @@ -964,18 +968,17 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - batch = {"pixel_values": pixel_values, "prompts": prompts} + batch = { + "indices": indices, + "pixel_values": pixel_values, + "instance_prompts": instance_prompts, + "prompts": prompts, + } return batch class BucketBatchSampler(BatchSampler): - def __init__( - self, - dataset: DreamBoothDataset, - batch_size: int, - drop_last: bool = False, - shuffle_batches_each_epoch: bool = True, - ): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, seed: int = None): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -984,7 +987,7 @@ def __init__( self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last - self.shuffle_batches_each_epoch = shuffle_batches_each_epoch + self.generator = random.Random(seed) if seed is not None else random # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -992,29 +995,25 @@ def __init__( self.bucket_indices[bucket_idx].append(idx) self.sampler_len = 0 - self.batches = [] + for indices_in_bucket in self.bucket_indices: + num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size) + self.sampler_len += num_batches + if remainder > 0 and not self.drop_last: + self.sampler_len += 1 - # Pre-generate batches for each bucket + def __iter__(self): + batches = [] for indices_in_bucket in self.bucket_indices: - # Shuffle indices within the bucket - random.shuffle(indices_in_bucket) - # Create batches - for i in range(0, len(indices_in_bucket), self.batch_size): - batch = indices_in_bucket[i : i + self.batch_size] + shuffled_indices = indices_in_bucket.copy() + self.generator.shuffle(shuffled_indices) + for i in range(0, len(shuffled_indices), self.batch_size): + batch = shuffled_indices[i : i + self.batch_size] if len(batch) < self.batch_size and self.drop_last: - continue # Skip partial batch if drop_last is True - self.batches.append(batch) - self.sampler_len += 1 # Count the number of batches - - if not self.shuffle_batches_each_epoch: - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + continue + batches.append(batch) - def __iter__(self): - if self.shuffle_batches_each_epoch: - random.shuffle(self.batches) - for batch in self.batches: + self.generator.shuffle(batches) + for batch in batches: yield batch def __len__(self): @@ -1473,13 +1472,8 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - batch_sampler = BucketBatchSampler( - train_dataset, - batch_size=args.train_batch_size, - drop_last=True, - shuffle_batches_each_epoch=not has_step_indexed_caches, - ) + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True, seed=args.seed) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1542,30 +1536,70 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) text_ids = torch.cat([text_ids, class_text_ids], dim=0) + static_prompt_embeds = prompt_embeds + static_text_ids = text_ids # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. + if args.cache_latents: + instance_latents_cache = [None] * train_dataset.num_instance_images + class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None + if train_dataset.custom_instance_prompts: + prompt_embeds_cache = [None] * train_dataset.num_instance_images + text_ids_cache = [None] * train_dataset.num_instance_images if precompute_latents: - prompt_embeds_cache = [] - text_ids_cache = [] - latents_cache = [] - for batch in tqdm(train_dataloader, desc="Caching latents"): + cache_batch_sampler = BucketBatchSampler( + train_dataset, batch_size=args.train_batch_size, drop_last=False, seed=args.seed + ) + cache_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=cache_batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + for batch in tqdm(cache_dataloader, desc="Caching latents"): with torch.no_grad(): + sample_indices = batch["indices"] if args.cache_latents: with offload_models(vae, device=accelerator.device, offload=args.offload): batch["pixel_values"] = batch["pixel_values"].to( accelerator.device, non_blocking=True, dtype=vae.dtype ) - latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + latents = vae.encode(batch["pixel_values"]).latent_dist.mode() + if args.with_prior_preservation: + instance_latents, class_latents = torch.chunk(latents, 2, dim=0) + else: + instance_latents = latents + for i, idx in enumerate(sample_indices): + instance_latents_cache[idx] = instance_latents[i : i + 1] + if args.with_prior_preservation: + class_latents_cache[idx] = class_latents[i : i + 1] if train_dataset.custom_instance_prompts: if args.fsdp_text_encoder: - prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds, text_ids = compute_text_embeddings( + batch["instance_prompts"], text_encoding_pipeline + ) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): - prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) - prompt_embeds_cache.append(prompt_embeds) - text_ids_cache.append(text_ids) + prompt_embeds, text_ids = compute_text_embeddings( + batch["instance_prompts"], text_encoding_pipeline + ) + for i, idx in enumerate(sample_indices): + prompt_embeds_cache[idx] = prompt_embeds[i : i + 1] + text_ids_cache[idx] = text_ids[i : i + 1] + + if args.cache_latents: + assert all(latents is not None for latents in instance_latents_cache), "Latent cache has unfilled entries." + if args.with_prior_preservation: + assert all(latents is not None for latents in class_latents_cache), ( + "Class latent cache has unfilled entries." + ) + if train_dataset.custom_instance_prompts: + assert all(embeds is not None for embeds in prompt_embeds_cache), ( + "Prompt embedding cache has unfilled entries." + ) + assert all(ids is not None for ids in text_ids_cache), "Text ID cache has unfilled entries." # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 if args.cache_latents: @@ -1688,25 +1722,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() - for step, batch in enumerate(train_dataloader): + for batch in train_dataloader: models_to_accumulate = [transformer] + sample_indices = batch["indices"] prompts = batch["prompts"] with accelerator.accumulate(models_to_accumulate): if train_dataset.custom_instance_prompts: - prompt_embeds = prompt_embeds_cache[step] - text_ids = text_ids_cache[step] + prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in sample_indices], dim=0) + text_ids = torch.cat([text_ids_cache[idx] for idx in sample_indices], dim=0) + if args.with_prior_preservation: + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0 + ) + text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0) else: # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries, # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...]. num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) - prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) - text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0) + prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) + text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0) # Convert images to latent space if args.cache_latents: - model_input = latents_cache[step].mode() + model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0) + if args.with_prior_preservation: + model_input = torch.cat( + [model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)], + dim=0, + ) else: with offload_models(vae, device=accelerator.device, offload=args.offload): pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)