Merged
147 changes: 96 additions & 51 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -905,6 +905,7 @@ def __len__(self):
def __getitem__(self, index):
example = {}
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["index"] = index
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
if self.custom_instance_prompts:
@@ -957,7 +958,10 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip


def collate_fn(examples, with_prior_preservation=False):
indices = [example["index"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Keep instance_prompts unchanged for prompt cache precompute; prompts may be extended with class prompts below.
instance_prompts = [example["instance_prompt"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]

# Concat class and instance examples for prior preservation.
@@ -969,18 +973,17 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

batch = {"pixel_values": pixel_values, "prompts": prompts}
batch = {
"indices": indices,
"pixel_values": pixel_values,
"instance_prompts": instance_prompts,

Member


Why do we need the instance prompts?

Contributor Author


With train_batch_size=2, custom captions, and prior preservation, collate_fn receives two examples:

examples = [
    {
        "index": 0,
        "instance_prompt": "a sks dog on grass",
        "instance_images": inst_img_0,
        "class_prompt": "a dog",
        "class_images": class_img_0,
    },
    {
        "index": 1,
        "instance_prompt": "a sks dog near water",
        "instance_images": inst_img_1,
        "class_prompt": "a dog",
        "class_images": class_img_1,
    },
]

First, it stores sample indices and instance image tensors:

indices = [example["index"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]

Result:

indices = [0, 1]
pixel_values = [inst_img_0, inst_img_1]

Then it builds two prompt lists from the same source:

instance_prompts = [example["instance_prompt"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]

Result:

instance_prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

If prior preservation is enabled, collate_fn appends the class images and class prompts to pixel_values and prompts only:

if with_prior_preservation:
    pixel_values += [example["class_images"] for example in examples]
    prompts += [example["class_prompt"] for example in examples]

Result:

pixel_values = [
    inst_img_0,
    inst_img_1,
    class_img_0,
    class_img_1,
]

prompts = [
    "a sks dog on grass",
    "a sks dog near water",
    "a dog",
    "a dog",
]

instance_prompts does not change:

instance_prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

Batch:

batch = {
    "indices": [0, 1],
    "pixel_values": tensor([inst_img_0, inst_img_1, class_img_0, class_img_1]),
    "instance_prompts": [
        "a sks dog on grass",
        "a sks dog near water",
    ],
    "prompts": [
        "a sks dog on grass",
        "a sks dog near water",
        "a dog",
        "a dog",
    ],
}
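
The walkthrough above can be reproduced without torch. This is only a sketch: collate_sketch is a hypothetical stand-in for collate_fn, and the images are plain string labels instead of tensors, so the torch.stack step is skipped.

```python
def collate_sketch(examples, with_prior_preservation=False):
    """Plain-Python stand-in for collate_fn; images are opaque labels, not tensors."""
    indices = [example["index"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    instance_prompts = [example["instance_prompt"] for example in examples]
    # Separate list object, so extending it below leaves instance_prompts intact.
    prompts = [example["instance_prompt"] for example in examples]

    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]

    return {
        "indices": indices,
        "pixel_values": pixel_values,
        "instance_prompts": instance_prompts,
        "prompts": prompts,
    }


examples = [
    {
        "index": 0,
        "instance_prompt": "a sks dog on grass",
        "instance_images": "inst_img_0",
        "class_prompt": "a dog",
        "class_images": "class_img_0",
    },
    {
        "index": 1,
        "instance_prompt": "a sks dog near water",
        "instance_images": "inst_img_1",
        "class_prompt": "a dog",
        "class_images": "class_img_1",
    },
]

batch = collate_sketch(examples, with_prior_preservation=True)
print(batch["prompts"])
print(batch["instance_prompts"])
```

The only point of the sketch is that prompts grows to 2 * batch size while instance_prompts and indices stay at batch size.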

Why is the instance-only prompt list needed?

During custom-caption precompute, we cache prompt embeddings per dataset sample.

See lines 1604-1606:

if train_dataset.custom_instance_prompts:
    prompt_embeds_cache = [None] * train_dataset.num_instance_images
    text_ids_cache = [None] * train_dataset.num_instance_images

Then, inside the cache loop, we need to encode only the instance captions, not the class prompts.

See lines 1632-1646:

if train_dataset.custom_instance_prompts:
    prompt_embeds, text_ids = compute_text_embeddings(
        batch["instance_prompts"], text_encoding_pipeline
    )
    for i, idx in enumerate(sample_indices):
        prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
        text_ids_cache[idx] = text_ids[i : i + 1]

Using the example:

batch["instance_prompts"] = ["a sks dog on grass", "a sks dog near water"]

So we cache:

prompt_embeds_cache[0] = embed("a sks dog on grass")
prompt_embeds_cache[1] = embed("a sks dog near water")
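
A minimal sketch of the sample-indexed fill, assuming a hypothetical fake_embed in place of compute_text_embeddings and a third made-up caption ("a sks dog indoors") so that the batches can arrive out of order:

```python
def fake_embed(prompt):
    """Hypothetical stand-in for compute_text_embeddings: one 'embedding' per prompt."""
    return ("embed", prompt)


def fill_prompt_cache(batches, num_instance_images):
    """Fill a cache keyed by dataset index, regardless of the order batches arrive in."""
    cache = [None] * num_instance_images
    for batch in batches:
        embeds = [fake_embed(prompt) for prompt in batch["instance_prompts"]]
        for i, idx in enumerate(batch["indices"]):
            cache[idx] = embeds[i]
    # Same kind of guard the PR adds after the precompute loop.
    assert all(entry is not None for entry in cache), "Prompt embedding cache has unfilled entries."
    return cache


# Batches deliberately arrive in a shuffled order.
batches = [
    {"indices": [2], "instance_prompts": ["a sks dog indoors"]},
    {"indices": [1, 0], "instance_prompts": ["a sks dog near water", "a sks dog on grass"]},
]
cache = fill_prompt_cache(batches, num_instance_images=3)
print(cache[0])
```

Because each entry is written at cache[idx], the position of a sample inside a batch, or of a batch inside the epoch, has no effect on where its embedding ends up.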

Custom captions are per instance image.

Class prompts are not cached the same way, because the class prompt is shared. Usually it is one fixed string:

--class_prompt "a dog"

So every class image uses the same prompt.

The script already encodes this fixed class prompt separately before training:

class_prompt_hidden_states, class_text_ids = compute_text_embeddings(
    args.class_prompt, text_encoding_pipeline
)

Then the training loop repeats the class embedding once per instance sample in the current batch, i.e. len(sample_indices) times.

See lines 1791-1795:

if args.with_prior_preservation:
    prompt_embeds = torch.cat(
        [prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0
    )
    text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0)

Using the example:

prompt_embeds = [
    cached_embed_for_sample_0,
    cached_embed_for_sample_1,
    embed("a dog"),
    embed("a dog"),
]

This matches the image/latent order.

See lines 1805-1810:

model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)
if args.with_prior_preservation:
    model_input = torch.cat(
        [model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)],
        dim=0,
    )

So the latent batch is:

model_input = [
    instance_latent_0,
    instance_latent_1,
    class_latent_0,
    class_latent_1,
]

And the prompt batch is:

prompt_embeds = [
    embed("a sks dog on grass"),
    embed("a sks dog near water"),
    embed("a dog"),
    embed("a dog"),
]

The ordering matches.
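
The alignment can be checked in isolation with plain lists. In this sketch the helper names assemble_prompt_batch and assemble_latent_batch are hypothetical, strings stand in for tensors, and list concatenation stands in for torch.cat along dim 0.

```python
def assemble_prompt_batch(sample_indices, prompt_cache, class_embed, with_prior_preservation):
    """Instance embeddings looked up by sample index, then the shared class embedding repeated."""
    embeds = [prompt_cache[idx] for idx in sample_indices]
    if with_prior_preservation:
        embeds += [class_embed] * len(sample_indices)
    return embeds


def assemble_latent_batch(sample_indices, instance_latents, class_latents, with_prior_preservation):
    """Instance latents first, then the class latents for the same sample indices."""
    latents = [instance_latents[idx] for idx in sample_indices]
    if with_prior_preservation:
        latents += [class_latents[idx] for idx in sample_indices]
    return latents


prompt_cache = ["embed(a sks dog on grass)", "embed(a sks dog near water)"]
instance_latents = ["instance_latent_0", "instance_latent_1"]
class_latents = ["class_latent_0", "class_latent_1"]

# A reshuffled batch: sample 1 comes before sample 0.
sample_indices = [1, 0]
prompt_batch = assemble_prompt_batch(sample_indices, prompt_cache, "embed(a dog)", True)
latent_batch = assemble_latent_batch(sample_indices, instance_latents, class_latents, True)

for latent, embed in zip(latent_batch, prompt_batch):
    print(latent, "<->", embed)
```

Both lists follow the [instance..., class...] layout, so position i in the latent batch always pairs with position i in the prompt batch, even when the sample order inside the batch changes.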

That said, the instance-only list does not have to be a separate batch key.

I could remove instance_prompts from collate_fn, because the first len(batch["indices"]) entries of batch["prompts"] are always the instance prompts.

Then we have:

batch = {
    "indices": indices,
    "pixel_values": pixel_values,
    "prompts": prompts,
}

Result:

batch["prompts"] = [
    "a sks dog on grass",
    "a sks dog near water",
    "a dog",
    "a dog",
]

batch["indices"] = [0, 1]

So we can derive:

sample_indices = batch["indices"]
instance_prompts = batch["prompts"][: len(sample_indices)]

Result:

instance_prompts = [
    "a sks dog on grass",
    "a sks dog near water",
]

Then cache precompute can use the local variable:

prompt_embeds, text_ids = compute_text_embeddings(instance_prompts, text_encoding_pipeline)

This gives the same behavior with less batch state.

"prompts": prompts,
}
return batch


class BucketBatchSampler(BatchSampler):
def __init__(
self,
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,

Member


Why is shuffle_batches_each_epoch going away? Because we're storing indices?

Contributor Author


Yes. shuffle_batches_each_epoch=False was needed because caches were step-indexed: cache[step] had to match the exact batch seen at that dataloader step during precompute.

This PR makes caches sample-indexed instead:

dataset index -> batch["indices"] -> cache[index] -> training retrieves cache[index]

So cache correctness no longer depends on stable dataloader step order, and epoch-wise reshuffling is safe.
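
A small sketch of the difference, using made-up batches of dataset indices and string labels in place of latents. The batch orders below are illustrative assumptions, not output of the actual sampler.

```python
num_samples = 6

# Order seen during the precompute pass.
precompute_batches = [[0, 1], [2, 3], [4, 5]]

# Step-indexed cache (old behaviour): entry k belongs to whatever batch was seen at step k.
step_cache = [[f"latent_{idx}" for idx in batch] for batch in precompute_batches]

# Sample-indexed cache (this PR): entry i belongs to dataset sample i.
sample_cache = [None] * num_samples
for batch in precompute_batches:
    for idx in batch:
        sample_cache[idx] = f"latent_{idx}"

# A later epoch yields the same batches in a different order.
later_epoch_batches = [[4, 5], [0, 1], [2, 3]]

expected = [[f"latent_{idx}" for idx in batch] for batch in later_epoch_batches]
step_lookup = [step_cache[step] for step, _ in enumerate(later_epoch_batches)]
sample_lookup = [[sample_cache[idx] for idx in batch] for batch in later_epoch_batches]

print("step-indexed matches:  ", step_lookup == expected)
print("sample-indexed matches:", sample_lookup == expected)
```

The step-indexed lookup returns the latents of the batch that happened to be at that step during precompute, while the sample-indexed lookup follows batch["indices"] and stays correct under any reordering.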

):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, seed: int = None):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
@@ -989,37 +992,33 @@ def __init__(
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
self.generator = random.Random(seed) if seed is not None else random

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
self.bucket_indices[bucket_idx].append(idx)

self.sampler_len = 0
self.batches = []
for indices_in_bucket in self.bucket_indices:
num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size)
self.sampler_len += num_batches
if remainder > 0 and not self.drop_last:
self.sampler_len += 1

# Pre-generate batches for each bucket
def __iter__(self):
batches = []
for indices_in_bucket in self.bucket_indices:
# Shuffle indices within the bucket
random.shuffle(indices_in_bucket)
# Create batches
for i in range(0, len(indices_in_bucket), self.batch_size):
batch = indices_in_bucket[i : i + self.batch_size]
shuffled_indices = indices_in_bucket.copy()
self.generator.shuffle(shuffled_indices)
for i in range(0, len(shuffled_indices), self.batch_size):
batch = shuffled_indices[i : i + self.batch_size]
if len(batch) < self.batch_size and self.drop_last:
continue # Skip partial batch if drop_last is True
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

if not self.shuffle_batches_each_epoch:
# Shuffle the precomputed batches once to mix buckets while keeping
# the order stable across epochs for step-indexed caches.
random.shuffle(self.batches)
continue
batches.append(batch)

def __iter__(self):
if self.shuffle_batches_each_epoch:
random.shuffle(self.batches)
for batch in self.batches:
self.generator.shuffle(batches)
for batch in batches:
yield batch

def __len__(self):
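
Because the hunk above interleaves removed and added lines, the resulting sampler is easier to follow as a standalone sketch. This is a reading of the new version under stated assumptions, not the merged code: the class name BucketBatchSamplerSketch is hypothetical, it takes a plain list mapping each sample to its bucket instead of a DreamBoothDataset, and it does not subclass BatchSampler.

```python
import random


class BucketBatchSamplerSketch:
    """Seeded bucket sampler: length fixed in __init__, batches rebuilt on every __iter__."""

    def __init__(self, bucket_of_sample, batch_size, drop_last=False, seed=None):
        self.batch_size = batch_size
        self.drop_last = drop_last
        # A private generator when a seed is given, else the module-level RNG.
        self.generator = random.Random(seed) if seed is not None else random

        # Group sample indices by bucket.
        num_buckets = max(bucket_of_sample) + 1
        self.bucket_indices = [[] for _ in range(num_buckets)]
        for idx, bucket_idx in enumerate(bucket_of_sample):
            self.bucket_indices[bucket_idx].append(idx)

        # Count batches up front instead of materialising them.
        self.sampler_len = 0
        for indices_in_bucket in self.bucket_indices:
            num_batches, remainder = divmod(len(indices_in_bucket), batch_size)
            self.sampler_len += num_batches
            if remainder > 0 and not drop_last:
                self.sampler_len += 1

    def __iter__(self):
        batches = []
        for indices_in_bucket in self.bucket_indices:
            # Shuffle a copy so the stored grouping is never mutated.
            shuffled_indices = indices_in_bucket.copy()
            self.generator.shuffle(shuffled_indices)
            for i in range(0, len(shuffled_indices), self.batch_size):
                batch = shuffled_indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue  # Skip partial batch if drop_last is True
                batches.append(batch)
        # Mix buckets; this happens again on every epoch.
        self.generator.shuffle(batches)
        yield from batches

    def __len__(self):
        return self.sampler_len


# Samples 0-2 fall in bucket 0, samples 3-4 in bucket 1.
sampler = BucketBatchSamplerSketch([0, 0, 0, 1, 1], batch_size=2, drop_last=False, seed=0)
epoch_1 = list(sampler)
epoch_2 = list(sampler)
print(len(sampler), epoch_1, epoch_2)
```

Every batch still comes from a single bucket, and every epoch covers the same samples. Only the order changes, which is safe once the caches are keyed by sample index.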
@@ -1480,13 +1479,8 @@ def load_model_hook(models, input_dir):
center_crop=args.center_crop,
buckets=buckets,
)
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(
train_dataset,
batch_size=args.train_batch_size,
drop_last=True,
shuffle_batches_each_epoch=not has_step_indexed_caches,
)
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True, seed=args.seed)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
@@ -1599,32 +1593,72 @@ def _encode_single(prompt: str):
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
text_ids = torch.cat([text_ids, class_text_ids], dim=0)
static_prompt_embeds = prompt_embeds
static_text_ids = text_ids

# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
if args.cache_latents:
instance_latents_cache = [None] * train_dataset.num_instance_images
class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None
if train_dataset.custom_instance_prompts:
prompt_embeds_cache = [None] * train_dataset.num_instance_images
text_ids_cache = [None] * train_dataset.num_instance_images
if precompute_latents:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
cache_batch_sampler = BucketBatchSampler(
train_dataset, batch_size=args.train_batch_size, drop_last=False, seed=args.seed
)
cache_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=cache_batch_sampler,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
for batch in tqdm(cache_dataloader, desc="Caching latents"):
with torch.no_grad():
sample_indices = batch["indices"]
if args.cache_latents:
with offload_models(vae, device=accelerator.device, offload=args.offload):
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
Comment thread
azolotenkov marked this conversation as resolved.
if args.with_prior_preservation:
instance_latents, class_latents = torch.chunk(latents, 2, dim=0)
else:
instance_latents = latents
for i, idx in enumerate(sample_indices):
instance_latents_cache[idx] = instance_latents[i : i + 1]
if args.with_prior_preservation:
class_latents_cache[idx] = class_latents[i : i + 1]
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["instance_prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
prompt_embeds, text_ids = compute_text_embeddings(
batch["instance_prompts"], text_encoding_pipeline
)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
prompt_embeds_cache.append(prompt_embeds)
text_ids_cache.append(text_ids)
prompt_embeds, text_ids = compute_text_embeddings(
batch["instance_prompts"], text_encoding_pipeline
)
for i, idx in enumerate(sample_indices):
prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
text_ids_cache[idx] = text_ids[i : i + 1]

if args.cache_latents:
assert all(latents is not None for latents in instance_latents_cache), "Latent cache has unfilled entries."
if args.with_prior_preservation:
assert all(latents is not None for latents in class_latents_cache), (
"Class latent cache has unfilled entries."
)
if train_dataset.custom_instance_prompts:
assert all(embeds is not None for embeds in prompt_embeds_cache), (
"Prompt embedding cache has unfilled entries."
)
assert all(ids is not None for ids in text_ids_cache), "Text ID cache has unfilled entries."

# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if args.cache_latents:
@@ -1748,25 +1782,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()

for step, batch in enumerate(train_dataloader):
for batch in train_dataloader:
models_to_accumulate = [transformer]
sample_indices = batch["indices"]
prompts = batch["prompts"]
Comment thread
azolotenkov marked this conversation as resolved.

with accelerator.accumulate(models_to_accumulate):
if train_dataset.custom_instance_prompts:
prompt_embeds = prompt_embeds_cache[step]
text_ids = text_ids_cache[step]
prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in sample_indices], dim=0)
text_ids = torch.cat([text_ids_cache[idx] for idx in sample_indices], dim=0)
if args.with_prior_preservation:
prompt_embeds = torch.cat(
[prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0
)
text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0)
else:
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0)

# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].mode()
model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)
if args.with_prior_preservation:
model_input = torch.cat(
[model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)],
dim=0,
)
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
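
The code comment in this hunk about repeating "instead of interleaving" is easy to misread, so here is a list-based sketch of the two behaviours. The helpers below are hypothetical plain-Python analogues of tensor.repeat_interleave(n, dim=0) and tensor.repeat(n, 1, 1), with strings standing in for the rows of static_prompt_embeds.

```python
def repeat_interleave_rows(rows, n):
    """List analogue of tensor.repeat_interleave(n, dim=0): each row is repeated n times in place."""
    return [row for row in rows for _ in range(n)]


def repeat_rows(rows, n):
    """List analogue of tensor.repeat(n, 1, 1) on a 3-D tensor: the whole sequence is tiled n times."""
    return rows * n


# With prior preservation, the static embeddings hold one instance entry and one class entry.
static_prompt_embeds = ["instance", "class"]
batch_size = 2  # len(prompts) // 2 in the training loop

grouped = repeat_interleave_rows(static_prompt_embeds, batch_size)
tiled = repeat_rows(static_prompt_embeds, batch_size)

print(grouped)  # matches collate_fn's [inst1..instB, class1..classB] layout
print(tiled)    # alternates instance and class, which would misalign with the images
```

Only the grouped result lines up with the [instance..., class...] order that collate_fn produces for pixel_values.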