#!/usr/bin/env python

import os
import logging
import math
import accelerate
import datasets
from datasets import Dataset, Image
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler

from models.layer_aware_dual_lora import LayerUNet
from myutils.img_utils import meta_to_inpaint_dataset_format
import time
import gc
import argparse
import cv2
from torchvision.transforms.functional import to_tensor
import PIL.Image




# Module-level logger obtained from accelerate's logging helper, created at INFO level.
logger = get_logger(__name__, log_level="INFO")

# Maps the precision names accepted by --vae_precision to the matching torch dtype.
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}



def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
    """Return the text-encoder class named by a pretrained model's config.

    The ``PretrainedConfig`` stored under ``subfolder`` of the model is loaded
    at the requested ``revision`` and the first entry of its ``architectures``
    list selects the class.

    Args:
        pretrained_model_name_or_path: Model path or identifier handed to
            ``PretrainedConfig.from_pretrained``.
        revision: Revision handed to ``PretrainedConfig.from_pretrained``.
        subfolder: Sub-directory holding the text-encoder config.

    Returns:
        The ``CLIPTextModel`` or ``CLIPTextModelWithProjection`` class (the
        class object itself, not an instance).

    Raises:
        ValueError: If the config names any other architecture.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        revision=revision,
        subfolder=subfolder,
    )
    architecture = encoder_config.architectures[0]

    # The classes are imported lazily so only the one that is needed is pulled in.
    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    raise ValueError(f"{architecture} is not supported.")

def _load_resized_rgb(image, resolution):
    """Return ``image`` as a square RGB PIL image of side ``resolution``.

    Args:
        image: Either an already opened ``PIL.Image.Image`` or a filesystem
            path to an image file.
        resolution: Target width and height in pixels.

    Returns:
        A new ``PIL.Image.Image`` in mode ``"RGB"``; resizing (LANCZOS) happens
        before the mode conversion.

    Raises:
        FileNotFoundError: If ``image`` is an empty path or the path does not exist.
    """
    target_size = (resolution, resolution)
    if isinstance(image, PIL.Image.Image):
        resized = image.resize(target_size, PIL.Image.Resampling.LANCZOS)
    else:
        if not image:
            raise FileNotFoundError("Image path is empty!")
        if not os.path.exists(image):
            raise FileNotFoundError(f"Image path not found: '{image}'")
        # resize() returns an independent in-memory image, so the file handle
        # can be released as soon as the resized copy exists.
        with PIL.Image.open(image) as source:
            resized = source.resize(target_size, PIL.Image.Resampling.LANCZOS)
    return resized.convert("RGB")


def convert_to_np(image, resolution):
    """Load/resize ``image`` and return it as a channel-first RGB array.

    Args:
        image: A ``PIL.Image.Image`` or a path to an image file.
        resolution: Target width and height in pixels.

    Returns:
        ``numpy.ndarray`` of shape ``(3, resolution, resolution)``.

    Raises:
        FileNotFoundError: If ``image`` is an empty or non-existent path.
    """
    return np.array(_load_resized_rgb(image, resolution)).transpose(2, 0, 1)


def convert_to_np_single(image, resolution):
    """Load/resize ``image`` and return only its first channel.

    The image is converted to RGB first, so the returned plane is the red
    channel of that conversion.

    Args:
        image: A ``PIL.Image.Image`` or a path to an image file.
        resolution: Target width and height in pixels.

    Returns:
        ``numpy.ndarray`` of shape ``(1, resolution, resolution)``.

    Raises:
        FileNotFoundError: If ``image`` is an empty or non-existent path.
    """
    return np.array(_load_resized_rgb(image, resolution)).transpose(2, 0, 1)[:1, :, :]

class PoseEncoder(torch.nn.Module):
    """Two-layer convolutional encoder for keypoint heatmaps.

    Both layers are 3x3 convolutions with padding 1, each followed by ReLU,
    so the spatial size of the input is preserved.

    Args:
        num_keypoints: Number of input channels (one heatmap per keypoint).
        latent_channels: Number of channels produced by both convolutions.
        latent_size: Accepted for interface compatibility; not used by the module.
    """

    def __init__(self, num_keypoints=25, latent_channels=32, latent_size=64):
        super().__init__()
        conv_options = {"kernel_size": 3, "padding": 1}
        self.conv1 = torch.nn.Conv2d(num_keypoints, latent_channels, **conv_options)
        self.conv2 = torch.nn.Conv2d(latent_channels, latent_channels, **conv_options)

    def forward(self, heatmaps):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``heatmaps`` and return the result."""
        features = heatmaps
        for conv in (self.conv1, self.conv2):
            features = F.relu(conv(features))
        return features

class HumanParsingEncoder(torch.nn.Module):
    """Two-layer convolutional encoder for human-parsing maps.

    Both layers are 3x3 convolutions with padding 1, each followed by ReLU,
    so the spatial size of the input is preserved.

    Args:
        num_classes: Number of input channels (one per parsing class).
        latent_channels: Number of channels produced by both convolutions.
        latent_size: Accepted for interface compatibility; not used by the module.
    """

    def __init__(self, num_classes=20, latent_channels=32, latent_size=64):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=num_classes, out_channels=latent_channels, kernel_size=3, padding=1
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=latent_channels, out_channels=latent_channels, kernel_size=3, padding=1
        )

    def forward(self, parsing_maps):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``parsing_maps`` and return the result."""
        hidden = self.conv1(parsing_maps)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        return F.relu(hidden)

def parse_args():
    """Build the run configuration from ``Config`` defaults plus CLI overrides.

    A ``Config()`` instance supplies the defaults (``Config`` is not defined in
    this function; it is expected to be available at module level). Every
    command-line option defaults to ``None`` and only options that were
    actually supplied overwrite the corresponding ``Config`` attribute.

    Boolean switches follow the same rule: they are declared with
    ``default=None`` so that an omitted switch keeps the ``Config`` default
    instead of silently forcing it to ``False``. If ``Config`` does not define
    a switch at all, it falls back to ``False``.

    If the ``LOCAL_RANK`` environment variable is set (anything but ``-1``),
    it takes precedence over ``config.local_rank``.

    Returns:
        The populated ``Config`` instance.
    """
    parser = argparse.ArgumentParser(description="Script to train UNet for object removal.")
    switch_names = []

    def add_switch(flag, help_text):
        # default=None (instead of argparse's implicit False) marks "not given",
        # which lets the Config default survive when the switch is omitted.
        parser.add_argument(flag, action="store_true", default=None, help=help_text)
        switch_names.append(flag.lstrip("-"))

    # --- model / checkpoint selection ---
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--pretrained_unet", type=str, help="Path to pretrained UNet model.")
    parser.add_argument("--lora_rank", type=int, help="Rank for LoRA layers.")
    parser.add_argument("--lora_alpha", type=float, help="Alpha parameter for LoRA layers.")
    parser.add_argument("--dual_loss_weight", type=float, help="Weight for dual branch loss.")
    parser.add_argument("--comments", type=str, help="Comments for logging.")
    parser.add_argument("--pretrained_vae_model_name_or_path", type=str, help="Path to an improved VAE.")
    parser.add_argument(
        "--vae_precision",
        type=str,
        choices=["fp32", "fp16", "bf16"],
        help="VAE precision.",
    )
    parser.add_argument("--revision", type=str, help="Revision of pretrained model identifier.")
    parser.add_argument("--variant", type=str, help="Variant of the model files.")

    # --- data / output ---
    parser.add_argument(
        "--meta_path",
        type=str,
        nargs="*",
        help="Path to meta info about the dataset.",
    )
    parser.add_argument("--output_dir", type=str, help="Output directory.")
    parser.add_argument("--seed", type=int, help="Seed for reproducible training.")
    parser.add_argument("--resolution", type=int, help="Resolution for input images.")

    # --- optimisation schedule ---
    parser.add_argument(
        "--train_batch_size", type=int, help="Batch size per device for training."
    )
    parser.add_argument("--num_train_epochs", type=int)
    parser.add_argument("--max_train_steps", type=int, help="Total number of training steps.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        help="Number of gradient accumulation steps.",
    )
    add_switch("--gradient_checkpointing", "Whether to use gradient checkpointing.")
    parser.add_argument("--learning_rate", type=float, help="Initial learning rate.")
    parser.add_argument("--lr_scheduler", type=str, help="Learning rate scheduler type.")
    parser.add_argument(
        "--lr_warmup_steps", type=int, help="Number of warmup steps."
    )
    parser.add_argument("--dataloader_num_workers", type=int, help="Number of dataloader workers.")
    parser.add_argument("--adam_beta1", type=float, help="Adam beta1.")
    parser.add_argument("--adam_beta2", type=float, help="Adam beta2.")
    parser.add_argument("--adam_weight_decay", type=float, help="Adam weight decay.")
    parser.add_argument("--adam_epsilon", type=float, help="Adam epsilon.")
    parser.add_argument("--max_grad_norm", type=float, help="Max gradient norm.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        choices=["no", "fp16", "bf16"],
        help="Mixed precision training.",
    )
    parser.add_argument("--report_to", type=str, help="Reporting integration.")
    parser.add_argument("--local_rank", type=int, help="Local rank for distributed training.")

    # --- pose conditioning ---
    add_switch("--enable_pose_estimator", "Whether to enable pose estimator.")
    parser.add_argument(
        "--num_keypoints",
        type=int,
        help="Number of keypoints to use for pose estimation."
    )
    parser.add_argument(
        "--pose_latent_ch",
        type=int,
        help="Number of latent channels to use for pose estimation."
    )
    parser.add_argument(
        "--pose_path",
        type=str,
        help="Path to the pose estimation output directory."
    )

    # --- human-parsing conditioning ---
    add_switch("--enable_human_parsing", "Whether to enable human parsing.")
    parser.add_argument(
        "--num_parsing_classes",
        type=int,
        help="Number of classes in human parsing (20 for CIHP)."
    )
    parser.add_argument(
        "--parsing_latent_ch",
        type=int,
        help="Number of latent channels to use for human parsing."
    )
    parser.add_argument(
        "--parsing_path",
        type=str,
        help="Path to the human parsing output directory."
    )

    # --- model behaviour switches / weights ---
    add_switch("--use_mask_aware", "Whether to use mask-aware attention.")
    parser.add_argument(
        "--unet_scale",
        type=float,
        help="Scale factor for latent unet encoding."
    )
    add_switch("--enable_layer_exchange", "Enable inter-layer information exchange.")
    parser.add_argument(
        "--harmonization_weight",
        type=float,
        help="Weight for boundary harmonization."
    )
    add_switch(
        "--get_other_mask_info",
        "Whether to use other_masks info as additional condition embedding.",
    )

    args = parser.parse_args()
    config = Config()

    # Only explicitly supplied options replace the Config defaults.
    for arg_name, arg_value in vars(args).items():
        if arg_value is not None:
            setattr(config, arg_name, arg_value)

    # A switch that neither the CLI nor Config provides is treated as disabled,
    # so later `config.<switch>` reads never hit a missing attribute.
    for switch in switch_names:
        if not hasattr(config, switch):
            setattr(config, switch, False)

    # The launcher-provided LOCAL_RANK wins over any configured value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1:
        config.local_rank = env_local_rank

    return config

def main(config=None):
    if config is None:
        config = parse_args()

    logging_dir = os.path.join(config.output_dir, "logs")

    accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with=config.report_to,
        project_config=accelerator_project_config,
    )

    generator = torch.Generator(device=accelerator.device).manual_seed(config.seed)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    logger.info("Training UNet for object removal")

    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if config.seed is not None:
        set_seed(config.seed)

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)

    if config.enable_pose_estimator:
        pose_encoder = PoseEncoder(
            num_keypoints=config.num_keypoints,
            latent_channels=config.pose_latent_ch,
            latent_size=config.resolution // 8
        ).to(accelerator.device)
    if config.enable_human_parsing:
        parsing_encoder = HumanParsingEncoder(
            num_classes=config.num_parsing_classes,
            latent_channels=config.parsing_latent_ch,
            latent_size=config.resolution // 8
        ).to(accelerator.device)

    vae_path = (
        config.pretrained_model_name_or_path
        if config.pretrained_vae_model_name_or_path is None
        else config.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if config.pretrained_vae_model_name_or_path is None else None,
        revision=config.revision,
        variant=config.variant,
    )
    
    original_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="unet", revision=config.revision, variant=config.variant
    )
    if config.pretrained_unet is not None:
        original_unet = UNet2DConditionModel.from_pretrained(config.pretrained_unet)
    
    unet = LayerUNet(original_unet, lora_rank=config.lora_rank, lora_alpha=config.lora_alpha, use_mask_aware=config.use_mask_aware)
    logger.info(f"Created UNet with rank={config.lora_rank}, alpha={config.lora_alpha}, use_mask_aware={config.use_mask_aware}")

    if config.gradient_checkpointing:
        unet.original_unet.enable_gradient_checkpointing()

    optimizer_cls = torch.optim.AdamW
    lora_params = list(filter(lambda p: p.requires_grad, unet.get_lora_parameters()))
    assert len(lora_params) > 0, "No trainable LoRA parameters found!"
    optimizer = optimizer_cls(
        lora_params,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay,
        eps=config.adam_epsilon,
    )
    logger.info(f"Optimizing {len(lora_params)} LoRA parameters")

    if not isinstance(config.meta_path, list):
        meta_folder = os.path.dirname(config.meta_path)
    else:
        meta_folder = None
    dataset_dict = meta_to_inpaint_dataset_format(config.meta_path, meta_folder)

    dataset = Dataset.from_dict(dataset_dict).cast_column("input_image", Image()).cast_column("edited_image", Image()).cast_column("mask", Image())

    dataset_columns = ("input_image", "edited_image", "edit_prompt")
    original_image_column = dataset_columns[0]
    edit_prompt_column = dataset_columns[2]
    edited_image_column = dataset_columns[1]

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    tokenizer_1 = AutoTokenizer.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="tokenizer", revision=config.revision, use_fast=False,
    )
    tokenizer_2 = AutoTokenizer.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=config.revision, use_fast=False,
    )
    
    text_encoder_cls_1 = import_model_class_from_model_name_or_path(config.pretrained_model_name_or_path, config.revision)
    text_encoder_cls_2 = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision, subfolder="text_encoder_2"
    )

    text_encoder_1 = text_encoder_cls_1.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision
    )
    text_encoder_2 = text_encoder_cls_2.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=config.revision
    )

    text_encoder_1.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)
    tokenizers = [tokenizer_1, tokenizer_2]
    text_encoders = [text_encoder_1, text_encoder_2]

    vae.requires_grad_(False)
    text_encoder_1.requires_grad_(False)
    text_encoder_2.requires_grad_(False)

    noise_scheduler = DDPMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder="scheduler")

    def encode_prompt(text_encoders, tokenizers, prompt):
        """Encode ``prompt`` with every (tokenizer, text encoder) pair.

        Each tokenizer pads/truncates the prompt to its ``model_max_length``.
        From every encoder the second-to-last hidden state is kept, and these
        are concatenated along the feature axis. The pooled embedding is taken
        from element 0 of the encoder output and is overwritten on every pass,
        so the one returned belongs to the last encoder in ``text_encoders``.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)``.
        """
        per_encoder_embeds = []

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            token_ids = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids

            encoder_output = text_encoder(token_ids.to(text_encoder.device), output_hidden_states=True)
            pooled_prompt_embeds = encoder_output[0]
            hidden = encoder_output.hidden_states[-2]
            bs_embed, seq_len, _ = hidden.shape
            per_encoder_embeds.append(hidden.view(bs_embed, seq_len, -1))

        prompt_embeds = torch.concat(per_encoder_embeds, dim=-1)
        return prompt_embeds, pooled_prompt_embeds.view(bs_embed, -1)

    def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):
        """Encode every prompt without tracking gradients and stack the results.

        Args:
            prompts: Iterable of prompts; each one is passed to ``encode_prompt`` on its own.
            text_encoders: Text encoders forwarded to ``encode_prompt``.
            tokenizers: Tokenizers forwarded to ``encode_prompt``.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)``, each a ``torch.stack`` of
            the per-prompt results along a new leading axis.
        """
        with torch.no_grad():
            prompt_embeds_all = []
            pooled_prompt_embeds_all = []

            for prompt in prompts:
                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
                prompt_embeds_all.append(prompt_embeds)
                pooled_prompt_embeds_all.append(pooled_prompt_embeds)

            return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)

    # NOTE: this must be a sibling of compute_embeddings_for_prompts. It used to be
    # indented inside that function after its `return`, so it was never defined
    # and the later call `compute_time_ids()` raised NameError.
    def compute_time_ids():
        """Return the additional time ids, one row per sample of a training batch.

        A row is ``original_size + crop_top_left + target_size`` with both sizes
        equal to ``(config.resolution, config.resolution)`` and the crop origin
        ``(0, 0)``. The tensor uses ``weight_dtype``, lives on
        ``accelerator.device`` and has ``config.train_batch_size`` rows.
        """
        crops_coords_top_left = (0, 0)
        original_size = target_size = (config.resolution, config.resolution)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
        return add_time_ids.to(accelerator.device).repeat(config.train_batch_size, 1)

    one_time_prompt_embeds_all, one_time_add_text_embeds_all = compute_embeddings_for_prompts([""], text_encoders, tokenizers)
    add_time_ids = compute_time_ids()

    train_transforms = transforms.Compose([
        transforms.CenterCrop(config.resolution),
    ])

    if config.get_other_mask_info:
        class OtherMaskEncoder(torch.nn.Module):
            """Small CNN that turns a mask image into a flat embedding vector.

            Two 3x3 convolutions (8 then 16 channels, each followed by ReLU) are
            followed by adaptive average pooling to 8x8, flattening and a linear
            projection to ``out_dim``.

            Args:
                in_ch: Number of input channels.
                out_dim: Size of the returned embedding.
                resolution: Accepted for interface compatibility; not used by the module.
            """

            def __init__(self, in_ch=1, out_dim=128, resolution=512):
                super().__init__()
                nn = torch.nn
                pooled_side = 8
                layers = [
                    nn.Conv2d(in_ch, 8, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(8, 16, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((pooled_side, pooled_side)),
                    nn.Flatten(),
                    nn.Linear(16 * pooled_side * pooled_side, out_dim),
                ]
                self.encoder = nn.Sequential(*layers)

            def forward(self, x):
                """Run ``x`` through the encoder stack and return the embedding."""
                embedding = self.encoder(x)
                return embedding
        other_mask_encoder = OtherMaskEncoder(in_ch=1, out_dim=128, resolution=config.resolution).to('cpu')
        other_mask_encoder.eval()

    def preprocess_images(examples):
        """Load, normalise and crop the image and mask columns of a batch.

        Original and edited images are loaded as channel-first RGB arrays and
        scaled from ``[0, 255]`` to ``[-1, 1]``; masks are loaded as a single
        channel and scaled to ``[0, 1]``. Images and masks are concatenated
        before ``train_transforms`` runs so that all of them go through the
        transform in one call, then split apart again.

        Returns:
            ``(images, mask_tensor)`` where ``images`` holds the original
            images followed by the edited images, all stacked along axis 0.
        """
        resolution = config.resolution

        def load_column(column, converter):
            return np.concatenate([converter(image, resolution) for image in examples[column]])

        original_images = load_column(original_image_column, convert_to_np)
        edited_images = load_column(edited_image_column, convert_to_np)
        mask_images = load_column("mask", convert_to_np_single)

        images = torch.tensor(np.concatenate([original_images, edited_images]))
        images = 2 * (images / 255) - 1
        mask_tensor = torch.tensor(mask_images) / 255

        image_length = images.shape[0]
        cropped = train_transforms(torch.cat([images, mask_tensor]))
        return cropped[:image_length, :, :], cropped[image_length:, :, :]

    def preprocess_train(examples):
        """Turn a raw dataset batch into the tensors used by the training loop.

        The following keys are added to ``examples`` (which is also returned):

        * ``original_pixel_values`` / ``edited_pixel_values``: reshaped to
          ``(-1, 3, resolution, resolution)``.
        * ``mask_pixel_values``: reshaped to ``(-1, 1, resolution, resolution)``.
        * ``other_mask_embeds`` (only with ``config.get_other_mask_info``): one
          embedding per example, computed by ``other_mask_encoder`` from the
          clamped sum of that example's existing ``other_masks`` files, or from
          an all-zero mask when there are none.
        * ``pose_keypoints`` (only with ``config.enable_pose_estimator``): the
          array stored in ``<pose_path>/<image stem>.npy``, or zeros of shape
          ``(1, num_keypoints, 3)`` when it cannot be loaded.
        * ``parsing_maps`` (only with ``config.enable_human_parsing``): the
          ``"parsing_map"`` array of ``<parsing_path>/<image stem>.npz`` moved
          to channel-first float, or zeros when missing/unreadable.
        * ``prompt_embeds`` / ``add_text_embeds``: the precomputed prompt
          embeddings repeated once per example.
        """
        resolution = config.resolution

        preprocessed_images, preprocess_mask = preprocess_images(examples)
        original_images, edited_images = preprocessed_images.chunk(2)
        original_images = original_images.reshape(-1, 3, resolution, resolution)
        edited_images = edited_images.reshape(-1, 3, resolution, resolution)
        mask_images = preprocess_mask.reshape(-1, 1, resolution, resolution)
        examples["original_pixel_values"] = original_images
        examples["edited_pixel_values"] = edited_images
        examples["mask_pixel_values"] = mask_images

        if config.get_other_mask_info:
            other_mask_embeds = []
            for other_mask_paths in examples.get("other_masks", [[]] * len(mask_images)):
                mask_tensors = []
                for p in other_mask_paths:
                    if os.path.exists(p):
                        # convert() yields an in-memory copy, so the file can be closed right away.
                        with PIL.Image.open(p) as mask_file:
                            m = to_tensor(mask_file.convert('L')).unsqueeze(0)
                        m = torch.nn.functional.interpolate(m.float(), size=(resolution, resolution))
                        # Drop the temporary batch axis again so every mask is
                        # (1, res, res), the same layout as the all-zero fallback.
                        # Keeping it produced a 5-D encoder input below.
                        mask_tensors.append(m.squeeze(0))
                if mask_tensors:
                    multi_mask = torch.clamp(torch.sum(torch.stack(mask_tensors), dim=0), 0, 1)
                else:
                    multi_mask = torch.zeros(1, resolution, resolution)
                with torch.no_grad():
                    emb = other_mask_encoder(multi_mask.unsqueeze(0)).squeeze(0)
                other_mask_embeds.append(emb)
            examples["other_mask_embeds"] = torch.stack(other_mask_embeds)

        if config.enable_pose_estimator:
            pose_keypoints = []
            for image_path in examples.get("image_path", [""] * len(examples["original_pixel_values"])):
                filename = os.path.splitext(os.path.basename(image_path))[0]
                pose_path = os.path.join(config.pose_path, f"{filename}.npy")
                try:
                    # NOTE(review): allow_pickle=True can execute arbitrary code when
                    # loading a crafted file; only point pose_path at trusted data.
                    keypoints = np.load(pose_path, allow_pickle=True)
                except Exception as e:
                    # Best effort: a missing or unreadable pose file becomes an empty pose.
                    logger.warning(f"Could not load pose file '{pose_path}' for {filename}: {e}")
                    keypoints = np.zeros((1, config.num_keypoints, 3), dtype=np.float32)
                pose_keypoints.append(keypoints)
            examples["pose_keypoints"] = pose_keypoints

        if config.enable_human_parsing:
            parsing_maps = []
            for image_path in examples.get("image_path", [""] * len(examples["original_pixel_values"])):
                filename = os.path.splitext(os.path.basename(image_path))[0]
                parsing_npy_path = os.path.join(config.parsing_path, f"{filename}.npz")
                parsing_map = None
                if os.path.exists(parsing_npy_path):
                    try:
                        # The archive is closed as soon as the array has been read.
                        with np.load(parsing_npy_path) as archive:
                            parsing_map = archive["parsing_map"]
                        parsing_map = torch.from_numpy(parsing_map).permute(2, 0, 1).float()
                    except Exception as e:
                        logger.warning(f"Could not read parsing file '{parsing_npy_path}' for {filename}: {e}")
                        parsing_map = None
                else:
                    logger.warning(f"Parsing .npz file not found for {filename}: '{parsing_npy_path}'")
                if parsing_map is None:
                    # Fallback shared by the "missing" and "unreadable" cases.
                    parsing_map = torch.zeros(
                        (config.num_parsing_classes, resolution, resolution), dtype=torch.float32
                    )
                parsing_maps.append(parsing_map)
            examples["parsing_maps"] = parsing_maps

        bsz = len(examples[edit_prompt_column])
        examples["prompt_embeds"] = one_time_prompt_embeds_all.repeat(bsz, 1, 1, 1)
        examples["add_text_embeds"] = one_time_add_text_embeds_all.repeat(bsz, 1, 1)
        return examples

    with accelerator.main_process_first():
        train_dataset = dataset.shuffle(seed=config.seed).with_transform(preprocess_train)

    def collate_fn(batch):
        """Merge a list of per-example dicts into one batch dict.

        The keys of the first example decide what is collated. A column whose
        first value is a ``torch.Tensor`` is stacked along a new leading axis;
        any other column is returned as a plain list. Afterwards the optional
        conditioning columns are (re)built explicitly when their feature is
        enabled: ``pose_keypoints`` always as a list, ``parsing_maps`` and
        ``other_mask_embeds`` always stacked.
        """
        first = batch[0]
        result = {}
        for key, sample_value in first.items():
            column = [item[key] for item in batch]
            result[key] = torch.stack(column) if isinstance(sample_value, torch.Tensor) else column

        if config.enable_pose_estimator and "pose_keypoints" in first:
            result["pose_keypoints"] = [item["pose_keypoints"] for item in batch]
        if config.enable_human_parsing and "parsing_maps" in first:
            result["parsing_maps"] = torch.stack([item["parsing_maps"] for item in batch])
        if config.get_other_mask_info and "other_mask_embeds" in first:
            result["other_mask_embeds"] = torch.stack([item["other_mask_embeds"] for item in batch])
        return result

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=config.train_batch_size,
        num_workers=config.dataloader_num_workers,
    )

    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    if config.max_train_steps is None:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * config.gradient_accumulation_steps,
        num_training_steps=config.max_train_steps * config.gradient_accumulation_steps,
    )

    prepare_args = [unet, optimizer, train_dataloader, lr_scheduler]
    if config.enable_pose_estimator:
        prepare_args.append(pose_encoder)
    if config.enable_human_parsing:
        prepare_args.append(parsing_encoder)
    prepared = accelerator.prepare(*prepare_args)
    unet, optimizer, train_dataloader, lr_scheduler = prepared[:4]
    idx = 4
    if config.enable_pose_estimator:
        pose_encoder = prepared[idx]
        idx += 1
    if config.enable_human_parsing:
        parsing_encoder = prepared[idx]

    if config.pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[config.vae_precision])

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    if overrode_max_train_steps:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
    config.num_train_epochs = math.ceil(config.max_train_steps / num_update_steps_per_epoch)

    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps

    logger.info("***** Running Progressive Training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {config.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {config.max_train_steps}")
    
    stage1_steps = int(config.max_train_steps * 0.4)
    stage2_steps = int(config.max_train_steps * 0.4)
    stage3_steps = config.max_train_steps - stage1_steps - stage2_steps
    
    logger.info("***** Progressive Training Strategy *****")
    logger.info(f"  Stage 1 (Background Only): Steps 0-{stage1_steps} ({stage1_steps} steps)")
    logger.info(f"  Stage 2 (Protected Dual): Steps {stage1_steps+1}-{stage1_steps+stage2_steps} ({stage2_steps} steps)")
    logger.info(f"  Stage 3 (Fine Tuning): Steps {stage1_steps+stage2_steps+1}-{config.max_train_steps} ({stage3_steps} steps)")
    logger.info("******************************************")
    
    global_step = 1
    first_epoch = 0

    progress_bar = tqdm(
        range(0, config.max_train_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    latent_size = config.resolution // 8

    for epoch in range(first_epoch, config.num_train_epochs):
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            start_time = time.time()
            with accelerator.accumulate(unet):
                if config.pretrained_vae_model_name_or_path is not None:
                    edited_pixel_values = batch["edited_pixel_values"].to(dtype=weight_dtype)
                else:
                    edited_pixel_values = batch["edited_pixel_values"]

                edited_pixel_values = edited_pixel_values.to(torch.float16)
                latents = vae.encode(edited_pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                if config.pretrained_vae_model_name_or_path is None:
                    latents = latents.to(weight_dtype)

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states = batch["prompt_embeds"]
                add_text_embeds = batch["add_text_embeds"]

                if config.pretrained_vae_model_name_or_path is not None:
                    original_pixel_values = batch["original_pixel_values"].to(dtype=weight_dtype)
                else:
                    original_pixel_values = batch["original_pixel_values"]

                original_pixel_values = original_pixel_values.to(torch.float16)
                original_pixel_values = original_pixel_values * (batch["mask_pixel_values"] < 0.5)
                original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample() * vae.config.scaling_factor
                if config.pretrained_vae_model_name_or_path is None:
                    original_image_embeds = original_image_embeds.to(weight_dtype)

                mask_pixel_values = batch["mask_pixel_values"]
                mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(latent_size, latent_size))

                original_image_embeds_fg = original_image_embeds.clone()
                mask_latent_blur = F.interpolate(mask_pixel_values, size=original_image_embeds.shape[-2:], mode='nearest')
                original_image_embeds_fg = original_image_embeds_fg * (1 - mask_latent_blur * 0.95)
                
                original_image_embeds_bg = original_image_embeds
                
                concatenated_noisy_latents_fg = torch.cat([noisy_latents, mask_pixel_values, original_image_embeds_fg], dim=1)
                concatenated_noisy_latents_bg = torch.cat([noisy_latents, mask_pixel_values, original_image_embeds_bg], dim=1)

                unet.set_mdma_params(global_step, config.max_train_steps, mdma_lambda=1.0)

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                if config.enable_pose_estimator:
                    def keypoints_to_heatmaps(pose_keypoints_batch, image_size, num_keypoints, sigma=6):
                        """Rasterise keypoints into one heatmap channel per keypoint index.

                        Args:
                            pose_keypoints_batch: One entry per sample; each entry is iterated
                                as a sequence of persons, each person as a sequence of
                                ``(x, y, c)`` triples (``c`` is presumably a confidence
                                score -- confirm against the pose files).
                            image_size: ``(H, W)`` of the heatmaps to draw.
                            num_keypoints: Number of channels allocated per sample.
                            sigma: Radius, in pixels, of the disc drawn for a keypoint.

                        Returns:
                            Float32 tensor built from the per-sample arrays stacked on
                            axis 0, i.e. ``(len(batch), num_keypoints, H, W)``.
                        """
                        B = len(pose_keypoints_batch)
                        H, W = image_size
                        heatmaps = []
                        for person_kpts in pose_keypoints_batch:
                            hm = np.zeros((num_keypoints, H, W), np.float32)
                            for p in person_kpts:
                                for i, (x, y, c) in enumerate(p):
                                    # Keypoints with c <= 0.05 are skipped.
                                    if c > 0.05:
                                        # thickness -1 draws a filled disc of value 1 directly into channel i.
                                        cv2.circle(hm[i], (int(x), int(y)), sigma, 1, -1)
                            heatmaps.append(hm)
                        heatmaps = torch.from_numpy(np.stack(heatmaps, 0))
                        return heatmaps
                    raw_kpts = batch["pose_keypoints"]
                    pose_latent_size = 8
                    heatmaps = keypoints_to_heatmaps(
                        raw_kpts,
                        image_size=(config.resolution, config.resolution),
                        num_keypoints=config.num_keypoints
                    ).to(accelerator.device)
                    heatmaps = F.interpolate(heatmaps, size=(pose_latent_size, pose_latent_size))
                    pose_feats = pose_encoder(heatmaps)
                    B, C, L, _ = pose_feats.shape
                    pose_feats = pose_feats.view(B, C, L*L).permute(0, 2, 1)
                    cross_attention_dim = unet.config.cross_attention_dim if not hasattr(unet, "module") else unet.module.config.cross_attention_dim
                    proj = torch.nn.Linear(C, cross_attention_dim).to(pose_feats.device)
                    pose_tokens = proj(pose_feats)
                    if encoder_hidden_states.dim() == 4 and encoder_hidden_states.shape[1] == 1:
                        encoder_hidden_states = encoder_hidden_states.squeeze(1)
                    if encoder_hidden_states.dim() == 2:
                        encoder_hidden_states = encoder_hidden_states.unsqueeze(1)
                    if pose_tokens.dim() == 4:
                        pose_tokens = pose_tokens.squeeze(-1)
                    assert encoder_hidden_states.dim() == 3 and pose_tokens.dim() == 3, \
                        f"Shapes: {encoder_hidden_states.shape}, {pose_tokens.shape}"
                    encoder_hidden_states = torch.cat([encoder_hidden_states, pose_tokens], dim=1)
                if config.enable_human_parsing:
                    parsing_latent_size = 8
                    parsing_maps = batch["parsing_maps"]
                    parsing_maps = F.interpolate(parsing_maps, size=(parsing_latent_size, parsing_latent_size))
                    parsing_feats = parsing_encoder(parsing_maps)
                    B, C, L, _ = parsing_feats.shape
                    parsing_feats = parsing_feats.view(B, C, L*L).permute(0, 2, 1)
                    cross_attention_dim = unet.config.cross_attention_dim if not hasattr(unet, "module") else unet.module.config.cross_attention_dim
                    proj_parsing = torch.nn.Linear(C, cross_attention_dim).to(parsing_feats.device)
                    parsing_tokens = proj_parsing(parsing_feats)
                    if encoder_hidden_states.dim() == 4 and encoder_hidden_states.shape[1] == 1:
                        encoder_hidden_states = encoder_hidden_states.squeeze(1)
                    if encoder_hidden_states.dim() == 2:
                        encoder_hidden_states = encoder_hidden_states.unsqueeze(1)
                    if parsing_tokens.dim() == 4:
                        parsing_tokens = parsing_tokens.squeeze(-1)
                    assert encoder_hidden_states.dim() == 3 and parsing_tokens.dim() == 3, \
                        f"Shapes: {encoder_hidden_states.shape}, {parsing_tokens.shape}"
                    encoder_hidden_states = torch.cat([encoder_hidden_states, parsing_tokens], dim=1)

                if config.get_other_mask_info and "other_mask_embeds" in batch:
                    emb = batch["other_mask_embeds"]
                    if encoder_hidden_states.dim() == 4:
                        encoder_hidden_states = encoder_hidden_states.squeeze(1)
                    emb = emb.unsqueeze(1)
                    encoder_hidden_states = torch.cat([encoder_hidden_states, emb], dim=1)

                model_pred_fg = unet(
                    concatenated_noisy_latents_fg,
                    timesteps,
                    encoder_hidden_states,
                    added_cond_kwargs=added_cond_kwargs,
                    mask=mask_pixel_values,
                    branch="foreground",
                    return_dict=False,
                )[0]

                model_pred_bg = unet(
                    concatenated_noisy_latents_bg,
                    timesteps,
                    encoder_hidden_states,
                    added_cond_kwargs=added_cond_kwargs,
                    mask=mask_pixel_values,
                    branch="background",
                    return_dict=False,
                )[0]

                if noise_scheduler.config.prediction_type == "epsilon":
                    target_fg = noise
                    target_bg = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target_fg = noise_scheduler.get_velocity(latents, noise, timesteps)
                    target_bg = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if mask_pixel_values.shape[1] == 1:
                    mask_latent = mask_pixel_values.repeat(1, model_pred_fg.shape[1], 1, 1)
                else:
                    mask_latent = mask_pixel_values

                model_pred_fg = model_pred_fg.to(weight_dtype)
                target_fg = target_fg.to(weight_dtype)
                mask_latent = mask_latent.to(weight_dtype)
                model_pred_bg = model_pred_bg.to(weight_dtype)
                target_bg = target_bg.to(weight_dtype)

                def compute_foreground_region_loss(model_pred_fg, original_latents, mask_latent, bg_pred):
                    """Mean-squared error between prediction and latents, restricted to the mask region.

                    The region is the thresholded mask (values above 0.5) plus the
                    one-cell band returned by ``compute_boundary_mask_latent``. Both
                    tensors are multiplied by that region before the loss is taken, and
                    the mean runs over every element of the masked tensors.

                    Args:
                        model_pred_fg: 4-D prediction; its second dimension sets how many
                            channels the region mask is expanded to.
                        original_latents: Tensor compared against ``model_pred_fg``.
                        mask_latent: Mask with 2, 3 or 4 dimensions. Only index 0 along
                            the second axis is used once the mask is 4-D.
                        bg_pred: Accepted for call compatibility; not read here.

                    Returns:
                        Scalar tensor with the masked MSE.
                    """
                    _, channels, _, _ = model_pred_fg.shape

                    # Bring masks that lack batch and/or channel axes up to 4-D.
                    rank = mask_latent.dim()
                    if rank == 2:
                        mask_latent = mask_latent[None, None]
                    elif rank == 3:
                        mask_latent = mask_latent.unsqueeze(1)

                    # Taking index 0 on axis 1 covers both the single-channel and the
                    # multi-channel mask.
                    mask_2d = mask_latent[:, 0]

                    inside = (mask_2d > 0.5).float()
                    ring = compute_boundary_mask_latent(mask_2d, boundary_width=1)

                    per_channel = (-1, channels, -1, -1)
                    inside = inside.unsqueeze(1).expand(*per_channel)
                    ring = ring.unsqueeze(1).expand(*per_channel)

                    # Interior and ring may overlap, so cap the sum at 1.
                    region = torch.clamp(inside + ring, 0, 1)

                    return F.mse_loss(
                        model_pred_fg * region,
                        original_latents * region,
                        reduction="mean",
                    )
                
                def compute_boundary_mask_latent(mask_2d, boundary_width=1):
                    """Return, for each mask in the batch, the band of cells just outside it.

                    The mask is averaged with a square box filter of side
                    ``2 * boundary_width + 1`` (zero padded). Cells whose local mean
                    exceeds 0.1 form the dilated mask. The boundary is the dilated mask
                    minus the input mask, clamped to ``[0, 1]``.

                    Args:
                        mask_2d: Floating-point tensor of shape ``(B, H, W)``.
                        boundary_width: Half-width of the box filter, in cells.

                    Returns:
                        Tensor of shape ``(B, H, W)`` with the dtype of ``mask_2d``.
                    """
                    kernel_size = boundary_width * 2 + 1
                    kernel = torch.ones(kernel_size, kernel_size, device=mask_2d.device, dtype=mask_2d.dtype)
                    kernel = kernel.unsqueeze(0).unsqueeze(0) / (kernel_size ** 2)

                    # One batched convolution replaces the per-sample loop; the explicit
                    # channel axis makes every mask an independent single-channel image.
                    masks = mask_2d.unsqueeze(1)
                    dilated = F.conv2d(masks, kernel, padding=kernel_size // 2)
                    dilated = (dilated > 0.1).to(dtype=mask_2d.dtype)

                    boundary = torch.clamp(dilated - masks, 0, 1)

                    # Drop only the channel axis. A bare squeeze() would also remove H or
                    # W when either is 1 and break the (B, H, W) contract.
                    return boundary.squeeze(1)

                def get_training_stage(global_step, max_train_steps):
                    """Name the curriculum stage for the current point in training.

                    Args:
                        global_step: Number of optimisation steps completed so far.
                        max_train_steps: Total number of optimisation steps planned.

                    Returns:
                        ``"background_only"`` while progress is below 0.4,
                        ``"protected_dual"`` while it is below 0.8, and
                        ``"fine_tuning"`` otherwise.
                    """
                    fraction_done = global_step / max_train_steps
                    # (exclusive upper bound on progress, stage name), in schedule order.
                    schedule = (
                        (0.4, "background_only"),
                        (0.8, "protected_dual"),
                    )
                    for upper_bound, stage_name in schedule:
                        if fraction_done < upper_bound:
                            return stage_name
                    return "fine_tuning"
                
                def set_branch_training_mode(unet, stage):
                    """Toggle ``requires_grad`` on the dual-branch LoRA factors for a stage.

                    Every sub-module of ``unet`` that has an ``fg_to_q`` attribute is
                    treated as a dual-LoRA layer. Its ``fg_*`` and ``bg_*`` projections
                    (``to_q``, ``to_k``, ``to_v``, ``to_out``) have ``lora_A`` and
                    ``lora_B`` switched on or off. Other modules are left untouched.

                    Args:
                        unet: Module whose ``modules()`` are scanned.
                        stage: Stage name. ``"background_only"`` freezes the foreground
                            factors and trains the background ones; any other value
                            enables both branches.
                    """
                    train_fg = stage != "background_only"
                    train_bg = True
                    projections = ("to_q", "to_k", "to_v", "to_out")
                    lora_factors = ("lora_A", "lora_B")

                    for layer in unet.modules():
                        if not hasattr(layer, 'fg_to_q'):
                            continue

                        # Foreground factors first, then background.
                        for branch, trainable in (("fg", train_fg), ("bg", train_bg)):
                            for projection in projections:
                                adapter = getattr(layer, f"{branch}_{projection}")
                                for factor in lora_factors:
                                    getattr(adapter, factor).requires_grad_(trainable)

                        # The optional mask-aware attention trains whenever either
                        # branch does.
                        gate = getattr(layer, 'mask_aware_attn', None)
                        if gate is not None:
                            for weight in gate.parameters():
                                weight.requires_grad_(train_fg or train_bg)
                
                def compute_progressive_loss(model_pred_fg, model_pred_bg, target, mask_latent,
                                             original_latents, stage, dual_loss_weight, global_step, max_train_steps):
                    """Assemble the stage-dependent training loss and its logged components.

                    The background MSE against ``target`` is part of every stage.

                    * ``"background_only"``: the total is the background MSE; the
                      foreground loss is reported as a zero tensor.
                    * ``"protected_dual"``: adds the foreground region loss with a weight
                      of ``0.02 + 0.13 * ((global_step / max_train_steps - 0.4) / 0.4)``.
                    * any other stage: adds the foreground region loss with weight 0.05
                      plus 0.01 times the MSE between the mask-blended prediction and
                      ``target``.

                    Args:
                        model_pred_fg: Foreground-branch prediction.
                        model_pred_bg: Background-branch prediction.
                        target: Regression target for the background and blended terms.
                        mask_latent: Mask used for blending and for the foreground loss.
                        original_latents: Reference passed to the foreground region loss.
                        stage: Stage name selecting the recipe above.
                        dual_loss_weight: Accepted for call compatibility; not read here.
                        global_step: Current optimisation step.
                        max_train_steps: Total number of optimisation steps planned.

                    Returns:
                        Dict with ``total_loss``, ``loss_bg``, ``loss_fg``, ``stage`` and
                        ``fg_weight``; the last recipe also includes ``consistency_loss``.
                    """
                    # Shared by all three recipes.
                    loss_bg = F.mse_loss(model_pred_bg, target, reduction="mean")

                    if stage == "background_only":
                        return {
                            'total_loss': loss_bg,
                            'loss_bg': loss_bg,
                            'loss_fg': torch.tensor(0.0, device=target.device),
                            'stage': stage,
                            'fg_weight': 0.0,
                        }

                    # Both remaining recipes need the foreground term; the background
                    # prediction is handed over detached.
                    loss_fg = compute_foreground_region_loss(
                        model_pred_fg,
                        original_latents,
                        mask_latent,
                        model_pred_bg.detach(),
                    )

                    if stage == "protected_dual":
                        ramp = (global_step / max_train_steps - 0.4) / 0.4
                        fg_weight = 0.02 + 0.13 * ramp
                        return {
                            'total_loss': loss_bg + fg_weight * loss_fg,
                            'loss_bg': loss_bg,
                            'loss_fg': loss_fg,
                            'stage': stage,
                            'fg_weight': fg_weight,
                        }

                    # Blend: foreground prediction where the mask is set, background elsewhere.
                    blended = model_pred_fg * mask_latent + model_pred_bg * (1 - mask_latent)
                    consistency_loss = F.mse_loss(blended, target, reduction="mean") * 0.01
                    return {
                        'total_loss': loss_bg + 0.05 * loss_fg + consistency_loss,
                        'loss_bg': loss_bg,
                        'loss_fg': loss_fg,
                        'consistency_loss': consistency_loss,
                        'stage': stage,
                        'fg_weight': 0.05,
                    }
                
                stage = get_training_stage(global_step, config.max_train_steps)
                
                set_branch_training_mode(unet, stage)
                
                if stage != "background_only":
                    with torch.no_grad():
                        for module in unet.modules():
                            if hasattr(module, 'bg_to_q'):
                                module.bg_to_q.lora_A.requires_grad_(False)
                                module.bg_to_q.lora_B.requires_grad_(False)
                                module.bg_to_k.lora_A.requires_grad_(False)
                                module.bg_to_k.lora_B.requires_grad_(False)
                                module.bg_to_v.lora_A.requires_grad_(False)
                                module.bg_to_v.lora_B.requires_grad_(False)
                                module.bg_to_out.lora_A.requires_grad_(False)
                                module.bg_to_out.lora_B.requires_grad_(False)
                
                if stage != "background_only":
                    original_pixel_values = batch["original_pixel_values"].to(torch.float16)
                    original_latents = vae.encode(original_pixel_values).latent_dist.sample()
                    original_latents = original_latents * vae.config.scaling_factor
                    original_latents = original_latents.to(weight_dtype)
                else:
                    original_latents = None
                
                loss_dict = compute_progressive_loss(
                    model_pred_fg, model_pred_bg, target_bg, mask_latent,
                    original_latents, stage, config.dual_loss_weight, global_step, config.max_train_steps
                )
                
                loss = loss_dict['total_loss']

                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(lora_params, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                del model_pred_fg, model_pred_bg, target_bg, mask_latent
                if original_latents is not None:
                    del original_latents
                torch.cuda.empty_cache()

                cost_time = time.time() - start_time

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                log_dict = {
                    "train_loss": train_loss,
                    "loss_bg": loss_dict['loss_bg'].item(),
                    "loss_fg": loss_dict['loss_fg'].item(),
                    "training_stage": loss_dict['stage'],
                    "fg_weight": loss_dict.get('fg_weight', 0.0),
                    "step_time": cost_time
                }
                
                if 'consistency_loss' in loss_dict:
                    log_dict["consistency_loss"] = loss_dict['consistency_loss'].item()
                
                accelerator.log(log_dict, step=global_step)
                train_loss = 0.0

                gc.collect()
                torch.cuda.empty_cache()

            logs = {
                "step_loss": loss.detach().item(), 
                "lr": lr_scheduler.get_last_lr()[0],
                "stage": loss_dict['stage'],
                "fg_w": f"{loss_dict.get('fg_weight', 0.0):.2f}"
            }
            progress_bar.set_postfix(**logs)

            if global_step >= config.max_train_steps:
                break

        if global_step < config.max_train_steps:
            logger.warning(f"Training stopped early at step {global_step}. This might be due to an error.")

    accelerator.end_training()

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()