from accelerate import DistributedDataParallelKwargs
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset  # ''datasets'' is a library
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    # StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from pipelines.pipeline_roboface import StableDiffusionControlNetPipeline
# from models.controlnet_inj import ControlNetModel
# from models.unet_2d_condition_inj import UNet2DConditionModel

from models.unet_2d_condition_inj import UNet2DConditionModel
from models.controlnet_inj import ControlNetModel

from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

from dataloaders.new_dataset import ImageTextDegradationDataset

from typing import Mapping, Any
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

from models.mapper import Mapper, CleanMapper
from transformers import CLIPVisionModel, CLIPImageProcessor


import lpips
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# wandb is an optional dependency: import it only when diffusers reports it as
# available, so the script still loads on machines without it.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")

# Logger obtained from accelerate's `get_logger`.
logger = get_logger(__name__)

# Compose pipeline holding a single ToTensor step.
# NOTE(review): no use of `tensor_transforms` is visible in this part of the
# file - confirm it is still needed.
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

# ram_transforms = transforms.Compose([
#     transforms.Resize((384, 384)),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])


def get_inj_embedding(image, mapper, cleanMapper, clip_image_processor, image_encoder_without_proj, reshape=True):
    """Encode an image with the vision encoder and map the result to injection tokens.

    Args:
        image: Either a ``torch.Tensor`` batch, or a single image object that
            ``clip_image_processor`` accepts. For a tensor, each sample is
            permuted from (C, H, W) to (H, W, C), scaled by 255 and cast to
            uint8, so values are assumed to lie in ``[0, 1]``.
        mapper: Module called with a one-element list containing the first
            output of the encoder.
        cleanMapper: Module applied to the output of ``mapper``.
        clip_image_processor: Callable invoked as
            ``clip_image_processor(images=..., return_tensors='pt')`` whose
            result exposes ``pixel_values``.
        image_encoder_without_proj: Vision encoder. Its ``device`` and the dtype
            of its first parameter are applied to the pixel values and to both
            mapper modules (``Module.to`` moves/casts the mappers in place).
        reshape: When true, the mapped embedding is reshaped to
            ``(batch, -1, 512)``.

    Returns:
        The tensor produced by ``cleanMapper(mapper([encoder_output]))``,
        optionally reshaped. Gradients can flow into the mappers but never into
        the encoder.
    """
    if isinstance(image, torch.Tensor):
        pil_images = []
        for sample in image:
            sample = sample.detach().cpu()
            if sample.dtype == torch.bfloat16:
                # NumPy has no bfloat16, so `.numpy()` would raise on it.
                sample = sample.float()
            pixels = sample.permute(1, 2, 0).numpy() * 255
            # Saturate instead of letting the uint8 cast wrap around for values
            # that fall slightly outside [0, 1].
            pixels = np.clip(pixels, 0, 255).astype(np.uint8)
            pil_images.append(Image.fromarray(pixels))

        gt_clip_image = clip_image_processor(
            images=pil_images, return_tensors='pt').pixel_values
    else:
        gt_clip_image = clip_image_processor(
            images=[image], return_tensors='pt').pixel_values

    encoder_device = image_encoder_without_proj.device
    encoder_dtype = next(image_encoder_without_proj.parameters()).dtype
    gt_clip_image = gt_clip_image.to(device=encoder_device, dtype=encoder_dtype)

    # The encoder output is detached below, so recording an autograd graph for
    # the encoder forward pass would only waste memory.
    with torch.no_grad():
        gt_clip_image = F.interpolate(
            gt_clip_image, (224, 224), mode='bilinear')
        image_features = image_encoder_without_proj(
            gt_clip_image, output_hidden_states=True)
    image_embeddings = [image_features[0].detach()]

    mapper = mapper.to(device=encoder_device, dtype=encoder_dtype)
    cleanMapper = cleanMapper.to(device=encoder_device, dtype=encoder_dtype)

    inj_embedding = cleanMapper(mapper(image_embeddings))

    if reshape:
        inj_embedding = inj_embedding.reshape(inj_embedding.shape[0], -1, 512)
    return inj_embedding


def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row onto one RGB canvas of ``rows`` x ``cols`` tiles.

    The tile size is taken from the first image. The number of images must
    equal ``rows * cols``.
    """
    assert len(imgs) == rows * cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))

    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas


def log_validation(vae, text_encoder, tokenizer, unet, controlnet, mapper, cleanMapper, clip_image_processor, image_encoder_without_proj, args, accelerator, weight_dtype, step):
    """Generate validation samples, score them against the HR references and log the results.

    For every (prompt, LR image, HR image) triple the pipeline is run
    ``args.num_validation_images`` times, conditioned on the LR image and on the
    injection embedding computed from it by ``get_inj_embedding``. Each output
    is compared with the HR image using LPIPS (VGG backbone), PSNR and SSIM.
    Averages are written to the logger and to the tensorboard / wandb trackers
    registered on ``accelerator``.

    Args:
        vae, text_encoder, tokenizer: Components handed to the pipeline as-is.
        unet, controlnet: Models handed to the pipeline after
            ``accelerator.unwrap_model``.
        mapper, cleanMapper, clip_image_processor, image_encoder_without_proj:
            Forwarded to ``get_inj_embedding``.
        args: Parsed command-line namespace. Reads ``pretrained_model_name_or_path``,
            ``revision``, ``enable_xformers_memory_efficient_attention``, ``seed``,
            ``lr_image``, ``hr_image``, ``validation_prompt`` and
            ``num_validation_images``.
        accelerator: Supplies the device, ``unwrap_model`` and the trackers.
        weight_dtype: ``torch_dtype`` used to load the pipeline.
        step: Global step attached to every logged value.

    Returns:
        A list with one dict per triple, holding ``lr_image``, ``hr_image``,
        ``images``, ``validation_prompt``, ``avg_lpips``, ``avg_psnr`` and
        ``avg_ssim``.

    Raises:
        ValueError: If the numbers of prompts and images cannot be matched.
    """
    logger.info("Running validation... ")

    controlnet = accelerator.unwrap_model(controlnet)
    unet = accelerator.unwrap_model(unet)

    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(
        pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # LPIPS with the VGG backbone, left in its default float32 precision.
    lpips_model = lpips.LPIPS(net='vgg').to(accelerator.device)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(
            device=accelerator.device).manual_seed(args.seed)

    # Broadcast a single image pair or a single prompt so that the three lists
    # can be zipped together.
    if len(args.lr_image) == len(args.validation_prompt) and len(args.hr_image) == len(args.validation_prompt):
        lr_image_paths = args.lr_image
        hr_image_paths = args.hr_image
        validation_prompts = args.validation_prompt
    elif len(args.lr_image) == 1:
        lr_image_paths = args.lr_image * \
            len(args.validation_prompt)
        hr_image_paths = args.hr_image * \
            len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        lr_image_paths = args.lr_image
        hr_image_paths = args.hr_image
        validation_prompts = args.validation_prompt * \
            len(args.lr_image)
    else:
        raise ValueError(
            "number of `args.lr_image` and `args.validation_prompt` and `args.hr_image` should be checked in `parse_args`"
        )

    image_logs = []
    all_lpips_scores = []
    all_psnr_scores = []
    all_ssim_scores = []

    # Pipeline outputs and reference images are handled as RGB PIL images.
    # LPIPS expects an NCHW tensor in the range [-1, 1].
    lpips_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # Map [0, 1] to [-1, 1]
    ])
    # PSNR/SSIM are computed on H x W x C uint8 arrays in the range [0, 255].
    numpy_transform = transforms.Compose([
        transforms.ToTensor(),  # Map to [0, 1]
        lambda x: (x * 255).numpy().astype(np.uint8).transpose(1,
                                                               2, 0)  # C H W -> H W C, uint8
    ])

    for validation_prompt, lr_image_path, hr_image_path in zip(validation_prompts, lr_image_paths, hr_image_paths):
        # Load the conditioning (LR) image and the ground-truth (HR) image;
        # `convert` returns an independent copy, so the files can be closed.
        with Image.open(lr_image_path) as lr_file:
            lr_image_pil = lr_file.convert("RGB")
        with Image.open(hr_image_path) as hr_file:
            hr_image_pil = hr_file.convert("RGB")

        # Metric-ready versions of the reference, keyed by (width, height), so
        # the reference is converted only once per output size.
        reference_cache = {}

        generated_images_pil = []
        current_lpips = []
        current_psnr = []
        current_ssim = []

        # The LR image is both the pipeline conditioning image and the source
        # of the injection embedding.
        conditioning_image_pil = lr_image_pil
        # Nothing is trained during validation, so skip the autograd graph.
        with torch.no_grad():
            inj_embedding = get_inj_embedding(
                image=conditioning_image_pil,
                mapper=mapper,
                cleanMapper=cleanMapper,
                clip_image_processor=clip_image_processor,
                image_encoder_without_proj=image_encoder_without_proj,
                reshape=False
            )
        for _ in range(args.num_validation_images):
            with torch.autocast("cuda"):
                generated_image_pil = pipeline(
                    prompt=validation_prompt,
                    image=conditioning_image_pil,
                    num_inference_steps=50,
                    generator=generator,
                    start_point='lr',
                    ram_encoder_hidden_states=inj_embedding,
                ).images[0]
            generated_images_pil.append(generated_image_pil)

            # All three metrics need equally sized inputs: resize the reference
            # to the output size when the two differ.
            output_size = generated_image_pil.size
            if output_size not in reference_cache:
                reference = _resize_to_size(hr_image_pil, output_size)
                reference_cache[output_size] = (
                    lpips_transform(reference).unsqueeze(
                        0).to(accelerator.device),
                    numpy_transform(reference),
                )
            gt_for_lpips, gt_for_numpy = reference_cache[output_size]

            # Keep the LPIPS inputs in float32 to match the LPIPS network
            # instead of casting them to the (possibly half) pipeline dtype.
            gen_for_lpips = lpips_transform(generated_image_pil).unsqueeze(
                0).to(accelerator.device)
            gen_for_numpy = numpy_transform(generated_image_pil)

            with torch.no_grad():
                lpips_val = lpips_model(gt_for_lpips, gen_for_lpips).item()

            # data_range matches the uint8 [0, 255] arrays from numpy_transform.
            psnr_val = psnr(gt_for_numpy, gen_for_numpy, data_range=255)
            ssim_val = ssim(gt_for_numpy, gen_for_numpy,
                            data_range=255, channel_axis=-1, win_size=7)

            current_lpips.append(lpips_val)
            current_psnr.append(psnr_val)
            current_ssim.append(ssim_val)

        # Average over the samples generated for this prompt / image pair.
        avg_lpips = np.mean(current_lpips)
        avg_psnr = np.mean(current_psnr)
        avg_ssim = np.mean(current_ssim)
        all_lpips_scores.append(avg_lpips)
        all_psnr_scores.append(avg_psnr)
        all_ssim_scores.append(avg_ssim)

        image_logs.append(
            {"lr_image": lr_image_pil, "hr_image": hr_image_pil, "images": generated_images_pil,
                "validation_prompt": validation_prompt,
                "avg_lpips": avg_lpips, "avg_psnr": avg_psnr, "avg_ssim": avg_ssim}
        )

    # Overall averages across all prompt / image pairs.
    final_avg_lpips = np.mean(all_lpips_scores)
    final_avg_psnr = np.mean(all_psnr_scores)
    final_avg_ssim = np.mean(all_ssim_scores)

    logger.info(
        f"Validation Metrics @ step {step}: LPIPS={final_avg_lpips:.4f}, PSNR={final_avg_psnr:.2f}, SSIM={final_avg_ssim:.4f}")

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            tracker.writer.add_scalar(
                "validation/avg_lpips", final_avg_lpips, step)
            tracker.writer.add_scalar(
                "validation/avg_psnr", final_avg_psnr, step)
            tracker.writer.add_scalar(
                "validation/avg_ssim", final_avg_ssim, step)

            for i, log in enumerate(image_logs):
                images = log["images"]
                lr_image = log["lr_image"]  # conditioning (LR) image
                hr_image = log["hr_image"]  # ground-truth (HR) image

                # np.stack needs equally sized arrays: bring LR, HR and the
                # outputs to one common size (a no-op when they already agree).
                panel_size = images[0].size if images else hr_image.size
                formatted_images = np.stack([
                    np.array(_resize_to_size(img, panel_size))  # HWC, uint8
                    for img in [lr_image, hr_image, *images]
                ])

                image_tag = f"Validation/sample_{i}"

                tracker.writer.add_images(
                    image_tag, formatted_images, step, dataformats="NHWC")

        elif tracker.name == "wandb":
            wandb_logs = {}
            wandb_logs["validation/avg_lpips"] = final_avg_lpips
            wandb_logs["validation/avg_psnr"] = final_avg_psnr
            wandb_logs["validation/avg_ssim"] = final_avg_ssim

            formatted_images_wandb = []
            for i, log in enumerate(image_logs):
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                lr_image = log["lr_image"]  # conditioning (LR) image
                hr_image = log["hr_image"]  # ground-truth (HR) image

                # Log the LR input and the HR reference under distinct labels.
                formatted_images_wandb.append(wandb.Image(
                    lr_image, caption=f"{i+1}_LR: {validation_prompt}"))
                formatted_images_wandb.append(wandb.Image(
                    hr_image, caption=f"{i+1}_GT: {validation_prompt}"))

                # Log generated images with their averaged metrics.
                log_prompt = (f"(LPIPS: {log['avg_lpips']:.4f}, "
                              f"PSNR: {log['avg_psnr']:.2f}, SSIM: {log['avg_ssim']:.4f})")
                for j, image in enumerate(images):
                    caption = f"{i+1}_Gen_{j+1}: {validation_prompt} {log_prompt}"
                    formatted_images_wandb.append(
                        wandb.Image(image, caption=caption))

            wandb_logs["validation_samples"] = formatted_images_wandb
            # Log metrics and images together.
            tracker.log(wandb_logs, step=step)
        else:
            logger.warning(
                f"Image logging and metric logging not implemented for {tracker.name}")

    # Drop the LPIPS model and release cached GPU memory.
    del lpips_model
    torch.cuda.empty_cache()

    return image_logs


def _resize_to_size(img, size):
    """Return ``img`` itself when it already has ``size`` (width, height), else a resized copy."""
    return img if img.size == size else img.resize(size)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named in the checkpoint's ``text_encoder`` config.

    Only the first entry of the config's ``architectures`` list is considered.
    The matching class is imported lazily.

    Raises:
        ValueError: If the architecture is neither ``CLIPTextModel`` nor
            ``RobertaSeriesModelWithTransformation``.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architecture = config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation

    raise ValueError(f"{architecture} is not supported.")


def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
    """Write ``README.md`` (YAML header plus a short model card) into ``repo_folder``.

    Args:
        repo_id: Identifier used in the card title.
        image_logs: Optional list of validation logs. Each entry needs
            ``images`` and ``validation_prompt`` plus the conditioning image,
            stored under ``validation_image`` or, as produced by
            ``log_validation`` in this file, under ``lr_image``. For every
            entry an ``images_<i>.png`` grid is saved next to the README.
        base_model: Name of the base model written into the card.
            NOTE(review): the default is the ``str`` type itself (probably
            meant as an annotation); callers should always pass a value.
        repo_folder: Existing directory that receives the README and images.
    """
    img_str = ""
    if image_logs is not None:
        img_str = "You can find some example images below.\n"
        for i, log in enumerate(image_logs):
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            # Accept the older "validation_image" key as well as the
            # "lr_image" key that log_validation emits.
            if "validation_image" in log:
                control_image = log["validation_image"]
            else:
                control_image = log["lr_image"]
            control_image.save(os.path.join(
                repo_folder, "image_control.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [control_image] + images
            image_grid(images, 1, len(images)).save(
                os.path.join(repo_folder, f"images_{i}.png"))
            img_str += f"![images_{i}](./images_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- controlnet
inference: true
---
    """
    model_card = f"""
# controlnet-{repo_id}

These are controlnet weights trained on {base_model} with new type of conditioning.
{img_str}
"""
    # Prompts may contain non-ASCII text, so do not rely on the platform's
    # default encoding.
    with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f:
        f.write(yaml + model_card)


def parse_args(input_args=None):
    """Build the command-line parser, parse the arguments and validate them.

    Args:
        input_args: Optional list of argument strings. When ``None`` the
            arguments are taken from ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: If the dataset options are missing or both given, if
            ``--proportion_empty_prompts`` is outside ``[0, 1]``, if the
            validation prompts / LR images / HR images are inconsistent, or if
            ``--resolution`` is not divisible by 8.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a ControlNet training script.")
    parser.add_argument(
        "--gt_image_dir",
        type=str,
        default="/home/data/FFHQ512x512",
        help="Path to the root folders containing the training data.",
    )
    parser.add_argument(
        "--text_dir",
        type=str,
        default="/home/data/FFHQ_text2",
        help="Path to the root folders containing the training data.",
    )
    parser.add_argument(
        "--null_text_ratio",
        type=float,
        default=0.0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    # The base model is switched to Stable Diffusion 2.1 here.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="/home/models/Group_AAA_Sr/weights/sd_2-1/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
        " If not specified controlnet weights are initialized from unet.",
    )
    parser.add_argument(
        "--unet_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained unet model or model identifier from huggingface.co/models."
        " If not specified controlnet weights are initialized from unet.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/home/data/Roboface/new_experience/lr_text_harder_noise",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None,
                        help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=30)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
            "instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default="checkpoint-10000",
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=24,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0,
                        help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9,
                        help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999,
                        help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float,
                        default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08,
                        help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0,
                        type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None,
                        help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default='NOTHING',
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the controlnet conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=["The woman in the image has a heart-shaped face with a prominent nose and full lips. Her eyes are large and expressive, with a hint of green in her irises. She has a smile that brightens her face, and her skin has a warm, sun-kissed tone. Her eyebrows are arched, and her expression is cheerful and friendly. There are no notable moles, freckles, or wrinkles visible on her face.", "The woman in the image has a round face shape with a prominent nose. Her eyes are brown, and she has thick eyebrows. Her skin tone is olive, and she has a frowning expression on her face. The woman is wearing a gray shirt, and her hair is pulled back into a ponytail.",
                 "The baby in the image has a round face shape with a small nose and wide eyes. The eyes are brown and have a hint of surprise or curiosity. The baby's mouth is slightly open, and the lips are pink. The baby has short, black hair and is wearing a white shirt. The baby's skin tone is light, and the overall expression is innocent and adorable.", "The woman in the image has a heart-shaped face with a prominent nose and full lips. Her eyes are large and expressive, with long eyelashes. She has a distinctive mole on her cheek, and her skin has a slightly pinkish tone. The woman is wearing a pink hat and a pink coat, which adds to her overall appearance. Her expression is a mix of surprise and curiosity, as she looks at the camera."],
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--lr_image`s, a single `--lr_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--lr_image`s."
        ),
    )
    parser.add_argument(
        "--lr_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
            " a single `--validation_prompt` to be used with all `--lr_image`s, or a single"
            " `--lr_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--hr_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the ground-truth images that validation outputs are scored against."
            " Must contain exactly one path per `--lr_image`, in the same order."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=1,
        help="Number of images to be generated for each `--lr_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=50,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default=None,
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    # Mapper and CleanMapper
    parser.add_argument('--mapper_pretrained_model_name_or_path', type=str, default=None,
                        help='The mapper to use for the training.')
    parser.add_argument('--clean_mapper_pretrained_model_name_or_path', type=str, default=None,
                        help='The clean mapper to use for the training.')
    parser.add_argument("--inj_num_token", default=30, type=int)
    parser.add_argument("--image_encoder_without_proj_path", type=str,
                        default=None)

    parser.add_argument('--trainable_modules', nargs='*',
                        type=str, default=["image_attentions"])

    # `parse_args(None)` reads sys.argv, so one call covers both cases.
    args = parser.parse_args(input_args)

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError(
            "Specify either `--dataset_name` or `--train_data_dir`")

    if args.dataset_name is not None and args.train_data_dir is not None:
        raise ValueError(
            "Specify only one of `--dataset_name` or `--train_data_dir`")

    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError(
            "`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.validation_prompt is not None and args.lr_image is None:
        raise ValueError(
            "`--lr_image` must be set if `--validation_prompt` is set")

    if args.validation_prompt is None and args.lr_image is not None:
        raise ValueError(
            "`--validation_prompt` must be set if `--lr_image` is set")

    if (
        args.lr_image is not None
        and args.validation_prompt is not None
        and len(args.lr_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.lr_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--lr_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--lr_image`s"
        )

    # Validation pairs every LR image with its HR reference, so fail here
    # rather than at the first validation step in the middle of training.
    if args.lr_image is not None:
        if args.hr_image is None:
            raise ValueError(
                "`--hr_image` must be set if `--lr_image` is set")
        if len(args.hr_image) != len(args.lr_image):
            raise ValueError(
                "`--hr_image` and `--lr_image` must contain the same number of paths")

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
        )

    return args


# def main(args):
# The training script runs at module level (the `main(args)` wrapper above is
# commented out), so everything from here on executes as soon as the file is run.
args = parse_args()
# Logging directory is `args.logging_dir` joined onto `args.output_dir`.
# NOTE(review): `Path(...)` raises if `args.output_dir` is None, although the
# repository-creation step below still guards against that value.
logging_dir = Path(args.output_dir, args.logging_dir)


# DDP options handed to the Accelerator through `kwargs_handlers` below.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

accelerator_project_config = ProjectConfiguration(
    project_dir=args.output_dir, logging_dir=logging_dir)

# Central object for gradient accumulation, mixed precision, tracker logging
# and distributed wrapping; all of these are configured from the CLI arguments.
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    project_config=accelerator_project_config,
    kwargs_handlers=[ddp_kwargs]
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
# Library logging is verbose only on the local main process; every other
# process is restricted to errors.
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# Handle the repository creation
if accelerator.is_main_process:
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.push_to_hub:
        # `repo_id` is only defined on the main process and only when pushing to
        # the hub; it falls back to the output directory's name when no
        # `--hub_model_id` is given.
        repo_id = create_repo(
            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
        ).repo_id

# Load the tokenizer
# An explicit `--tokenizer_name` takes precedence over the base model's tokenizer.
# NOTE(review): `tokenizer` stays undefined if neither argument is set.
if args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name, revision=args.revision, use_fast=False)
# Load the tokenizer from SD 2.1 (the pretrained base model's "tokenizer" subfolder)
elif args.pretrained_model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(
    args.pretrained_model_name_or_path, args.revision)

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
# unet = UNet2DConditionModel.from_pretrained(
#     args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
# )

# Mapper: built with hard-coded 1024-wide input/output and `--inj_num_token` words,
# then initialised from the checkpoint given on the command line.
# `strict=False` means missing or unexpected checkpoint keys do not raise.
mapper = Mapper(input_dim=1024, output_dim=1024, num_words=args.inj_num_token)
mapper_ckpt = torch.load(
    args.mapper_pretrained_model_name_or_path, map_location="cpu")
mapper.load_state_dict(mapper_ckpt, strict=False)

# CleanMapper: same construction and loading scheme as the mapper above.
clean_mapper = CleanMapper(
    input_dim=1024, output_dim=1024, num_words=args.inj_num_token)
clean_mapper_ckpt = torch.load(
    args.clean_mapper_pretrained_model_name_or_path, map_location="cpu")
clean_mapper.load_state_dict(clean_mapper_ckpt, strict=False)

# CLIP vision encoder loaded from `--image_encoder_without_proj_path`; the
# image processor is created with its default settings.
image_encoder_without_proj = CLIPVisionModel.from_pretrained(
    args.image_encoder_without_proj_path)
clip_image_processor = CLIPImageProcessor()

if args.unet_model_name_or_path:
    # resume from self-train
    logger.info("Loading unet weights from self-train")
    unet = UNet2DConditionModel.from_pretrained_orig(
        args.pretrained_model_name_or_path, args.unet_model_name_or_path, subfolder="unet", revision=args.revision, use_image_cross_attention=True
    )
else:
    # resume from pretrained SD
    logger.info("Loading unet weights from SD")
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_image_cross_attention=True, low_cpu_mem_usage=False
    )
    # Debug output; only printed on this (pretrained SD) branch.
    print(f'===== if use ram encoder? {unet.config.use_image_cross_attention}')

if args.controlnet_model_name_or_path:
    # resume from self-train
    logger.info("Loading existing controlnet weights")
    controlnet = ControlNetModel.from_pretrained(
        args.controlnet_model_name_or_path, subfolder="controlnet", use_image_cross_attention=True)

else:
    # No ControlNet checkpoint given: derive a fresh ControlNet from the UNet loaded above.
    logger.info("Initializing controlnet weights from unet")
    controlnet = ControlNetModel.from_unet(
        unet, use_image_cross_attention=True)

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        """Write the trained models in diffusers format during `accelerator.save_state`.

        Each model is stored with `save_pretrained` in its own sub-directory of
        `output_dir`: ``unet`` for `UNet2DConditionModel` instances and
        ``controlnet`` for any other model.

        Args:
            models: Models handed to the hook; exactly two are expected.
            weights: State dicts matching `models`. The list is emptied in place
                so that the corresponding models are not saved again.
            output_dir: Checkpoint directory currently being written.

        Raises:
            ValueError: If `models` or `weights` does not contain exactly two entries.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(models) != 2 or len(weights) != 2:
            raise ValueError(
                "save_model_hook expects exactly two models and two weight entries, "
                f"got {len(models)} model(s) and {len(weights)} weight entr(ies)."
            )

        for model in models:
            sub_dir = "unet" if isinstance(
                model, UNet2DConditionModel) else "controlnet"
            model.save_pretrained(os.path.join(output_dir, sub_dir))
            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

    def load_model_hook(models, input_dir):
        """Restore the trained models from a diffusers-format checkpoint directory.

        Every model is popped from `models` and refreshed in place: a
        `UNet2DConditionModel` is loaded from the ``unet`` sub-directory of
        `input_dir`, anything else from ``controlnet``. Both the config and the
        state dict of the freshly loaded model are copied over.

        Args:
            models: Models handed to the hook; exactly two are expected. The list
                is emptied in place so that the models are not loaded again.
            input_dir: Checkpoint directory to read from.

        Raises:
            ValueError: If `models` does not contain exactly two entries.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(models) != 2:
            raise ValueError(
                f"load_model_hook expects exactly two models, got {len(models)}."
            )

        while models:
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            if isinstance(model, UNet2DConditionModel):
                load_model = UNet2DConditionModel.from_pretrained(
                    input_dir, subfolder="unet")
            else:
                load_model = ControlNetModel.from_pretrained(
                    input_dir, subfolder="controlnet")

            model.register_to_config(**load_model.config)

            model.load_state_dict(load_model.state_dict())
            # Drop the temporary copy before the next model is loaded.
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

# Freeze every auxiliary model; only the ControlNet is trainable as a whole.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
mapper.requires_grad_(False)
clean_mapper.requires_grad_(False)
image_encoder_without_proj.requires_grad_(False)
controlnet.requires_grad_(True)
controlnet.train()

# Re-enable gradients for the UNet sub-modules whose qualified name ends with
# one of the `--trainable_modules` suffixes.
for name, module in unet.named_modules():
    if name.endswith(tuple(args.trainable_modules)):
        print(f'{name} in <unet> will be optimized.')
        for params in module.parameters():
            params.requires_grad = True


if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            # `warning` instead of the deprecated `warn` alias, which was
            # removed from the logging module in Python 3.13.
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
        controlnet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError(
            "xformers is not available. Make sure it is installed correctly")

if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    controlnet.enable_gradient_checkpointing()

# Check that all trainable models are in full precision
low_precision_error_string = (
    " Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training, copy of the weights should still be float32."
)

if accelerator.unwrap_model(controlnet).dtype != torch.float32:
    raise ValueError(
        f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
    )
if accelerator.unwrap_model(unet).dtype != torch.float32:
    raise ValueError(
        f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
    )


# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Optionally scale the base learning rate by the effective batch size
# (accumulation steps x per-device batch size x number of processes).
if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps *
        args.train_batch_size * accelerator.num_processes
    )

# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )

    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

# Optimizer creation
print(f'=================Optimize ControlNet and Unet ======================')
# All ControlNet and all UNet parameters are passed to the optimizer, i.e. the
# list also contains UNet parameters whose `requires_grad` was left False.
# NOTE(review): narrowing this list changes the optimizer's parameter-group
# layout, which would likely break loading optimizer state from existing checkpoints.
params_to_optimize = list(controlnet.parameters()) + list(unet.parameters())


print(f'start to load optimizer...')

optimizer = optimizer_class(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

# Training data: ground-truth images plus text, tokenized with the tokenizer
# loaded earlier; `--null_text_ratio` is forwarded as `empty_text_ratio`.
train_dataset = ImageTextDegradationDataset(
    gt_image_dir=args.gt_image_dir,
    gt_text_dir=args.text_dir,
    tokenizer=tokenizer,
    empty_text_ratio=args.null_text_ratio,
    resize_bak=True
)

# train_dataset = PairedCaptionDataset(root_folders=args.root_folders,
#                                      tokenizer=tokenizer,
#                                      null_text_ratio=args.null_text_ratio,
#                                      )

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=args.dataloader_num_workers,
    batch_size=args.train_batch_size,
    shuffle=True
)


# Scheduler and math around the number of training steps.
# `overrode_max_train_steps` records that `max_train_steps` was derived from the
# epoch count, so it can be recomputed once the dataloader has been prepared.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

# Warm-up and total step counts are multiplied by the number of processes.
# NOTE(review): presumably because the prepared scheduler advances once per
# process for every optimizer step — confirm against accelerate's scheduler wrapper.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)

# Prepare everything with our `accelerator`.
# Both trainable models (ControlNet and UNet) are prepared, in that order.
controlnet, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, unet, optimizer, train_dataloader, lr_scheduler
)

# The frozen, inference-only modules may run in the mixed-precision dtype;
# anything other than fp16/bf16 keeps them in full float32 precision.
weight_dtype = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}.get(accelerator.mixed_precision, torch.float32)

# Move every frozen module to the training device and cast it to `weight_dtype`.
for frozen_module in (
    vae,
    text_encoder,
    mapper,
    clean_mapper,
    image_encoder_without_proj,
):
    frozen_module.to(accelerator.device, dtype=weight_dtype)

# The dataloader length may have changed after `accelerator.prepare`, so the
# step bookkeeping is redone from the prepared dataloader.
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps
)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

# The epoch count is always re-derived from the final number of training steps.
args.num_train_epochs = math.ceil(
    args.max_train_steps / num_update_steps_per_epoch
)


# Initialise the experiment trackers and hand them the run configuration.
# The tracker configuration is only built on the main process.
if accelerator.is_main_process:
    tracker_config = dict(vars(args))

    # tensorboard cannot handle list types for config
    for list_valued_key in (
        "validation_prompt",
        "lr_image",
        "hr_image",
        "trainable_modules",
    ):
        tracker_config.pop(list_valued_key)

    accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

# Train!
# Effective batch size across devices and accumulation steps.
total_batch_size = (
    args.train_batch_size
    * accelerator.num_processes
    * args.gradient_accumulation_steps
)

logger.info("***** Running training *****")
for summary_line in (
    f"  Num examples = {len(train_dataset)}",
    f"  Num batches each epoch = {len(train_dataloader)}",
    f"  Num Epochs = {args.num_train_epochs}",
    f"  Instantaneous batch size per device = {args.train_batch_size}",
    f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}",
    f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}",
    f"  Total optimization steps = {args.max_train_steps}",
):
    logger.info(summary_line)
# Step/epoch counters for a fresh run; overwritten below when resuming.
global_step = 0
first_epoch = 0
initial_global_step = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint == "latest":
        # Pick the checkpoint directory with the highest step number.
        checkpoint_dirs = [
            entry
            for entry in os.listdir(args.output_dir)
            if entry.startswith("checkpoint")
        ]
        checkpoint_dirs.sort(key=lambda entry: int(entry.split("-")[1]))
        path = checkpoint_dirs[-1] if checkpoint_dirs else None
    else:
        # Only the directory name is kept; it is resolved against `output_dir`.
        path = os.path.basename(args.resume_from_checkpoint)

    if path is not None:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))

        # The step count is encoded in the directory name: "checkpoint-<step>".
        global_step = int(path.split("-")[1])
        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
    else:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None

progress_bar = tqdm(
    range(args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)


# Latest validation output. Initialised here so that the model-card step after
# training can reference it even when validation never ran.
# NOTE(review): assumes `save_model_card` accepts `image_logs=None` — confirm.
image_logs = None

for epoch in range(first_epoch, args.num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        # Register both trainable models in ONE `accumulate` call. Nesting two
        # separate `accelerator.accumulate(...)` contexts advances accelerate's
        # accumulation counter twice per batch, which corrupts the gradient
        # accumulation cadence whenever `gradient_accumulation_steps > 1`.
        with accelerator.accumulate(controlnet, unet):
            pixel_values = batch["pixel_values"].to(
                accelerator.device, dtype=weight_dtype)
            # Convert images to latent space
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(
                latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(
                batch["input_ids"].to(accelerator.device))[0]

            # Image-derived embeddings computed from the conditioning image via
            # the CLIP encoder and the two mappers; fed to both ControlNet and UNet.
            inj_embeddings = get_inj_embedding(
                image=batch["conditioning_pixel_values"],
                mapper=mapper,
                cleanMapper=clean_mapper,
                clip_image_processor=clip_image_processor,
                image_encoder_without_proj=image_encoder_without_proj,
                reshape=False
            )
            controlnet_image = batch["conditioning_pixel_values"].to(
                accelerator.device, dtype=weight_dtype)

            down_block_res_samples, mid_block_res_sample = controlnet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_image,
                return_dict=False,
                image_encoder_hidden_states=inj_embeddings,
            )

            # Predict the noise residual
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                down_block_additional_residuals=[
                    sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                ],
                mid_block_additional_residual=mid_block_res_sample.to(
                    dtype=weight_dtype),
                image_encoder_hidden_states=inj_embeddings,
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(
                    latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            # The loss is computed in float32 regardless of the mixed-precision dtype.
            loss = F.mse_loss(model_pred.float(),
                              target.float(), reduction="mean")

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                # Clip the gradients of both trainable models together.
                params_to_clip = list(
                    controlnet.parameters()) + list(unet.parameters())
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=args.set_grads_to_none)

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if accelerator.is_main_process:
                if global_step % args.checkpointing_steps == 0:

                    save_path = os.path.join(
                        args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                    image_logs = log_validation(
                        vae,
                        text_encoder,
                        tokenizer,
                        unet,
                        controlnet,
                        mapper,
                        clean_mapper,
                        clip_image_processor,
                        image_encoder_without_proj,
                        args,
                        accelerator,
                        weight_dtype,
                        global_step,
                    )

        # Logged for every batch, not only on optimizer steps.
        logs = {"loss": loss.detach().item(
        ), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:
            break

# Save the trained modules and optionally push them to the hub.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Each model goes into its own sub-directory. Saving both directly into
    # `args.output_dir` would let the second `save_pretrained` overwrite the
    # config and weight files written by the first. The "controlnet" / "unet"
    # layout matches the checkpoint hooks and the `subfolder=` arguments used
    # when these models are loaded at the start of this script.
    controlnet = accelerator.unwrap_model(controlnet)
    controlnet.save_pretrained(os.path.join(args.output_dir, "controlnet"))

    unet = accelerator.unwrap_model(unet)
    unet.save_pretrained(os.path.join(args.output_dir, "unet"))

    if args.push_to_hub:
        save_model_card(
            repo_id,
            image_logs=image_logs,
            base_model=args.pretrained_model_name_or_path,
            repo_folder=args.output_dir,
        )
        upload_folder(
            repo_id=repo_id,
            folder_path=args.output_dir,
            commit_message="End of training",
            ignore_patterns=["step_*", "epoch_*"],
        )

accelerator.end_training()
