#!/usr/bin/env python
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from tqdm.auto import tqdm
from transformers import (
    CLIPTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
    StableDiffusionXLPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import (
    check_min_version,
    is_wandb_available,
    convert_state_dict_to_diffusers,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available

from dti.datasets import DualTokenizerDataset
from dti.training import (
    add_new_token,
    replace_token_embedding,
    save_progress,
    retract_token_embeddings,
    project_grads_to_tangent_space,
)
from dti.utils import str2bool

if is_wandb_available():
    import wandb
# ------------------------------------------------------------------------------


# Would error if the minimum required version of diffusers is not installed (the check below is currently disabled). Re-enable or remove at your own risk.
# check_min_version("0.32.0.dev0")

logger = get_logger(__name__)


def save_model_card(repo_id: str, images=None, base_model=None, repo_folder=None):
    """Write example images and a ``README.md`` model card into ``repo_folder``.

    Args:
        repo_id: Hub repository id; used in the card title and forwarded to
            ``load_or_create_model_card``.
        images: Optional iterable of image objects exposing ``save(path)``
            (presumably PIL images -- confirm against the caller). Each one is
            written to ``repo_folder/image_{i}.png`` and linked from the card.
            ``None`` yields a card without example images.
        base_model: Name or path of the base model, shown in the card text and
            forwarded to ``load_or_create_model_card``.
        repo_folder: Directory that receives the images and ``README.md``. It is
            effectively required: ``Path(None)`` raises ``TypeError``.
    """
    # Resolve the target directory once; this also fails fast on a missing folder
    # before any image is written or the card is created.
    repo_dir = Path(repo_folder)

    image_links = []
    # ``images`` defaults to None, so it must not be iterated unconditionally.
    if images is not None:
        for i, image in enumerate(images):
            image.save(repo_dir / f"image_{i}.png")
            image_links.append(f"![img_{i}](./image_{i}.png)\n")
    img_str = "".join(image_links)

    model_description = f"""
# Textual inversion text2image fine-tuning - {repo_id}
These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
{img_str}
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="creativeml-openrail-m",
        base_model=base_model,
        model_description=model_description,
        inference=True,
    )

    tags = [
        "stable-diffusion-xl",
        "stable-diffusion-xl-diffusers",
        "text-to-image",
        "diffusers",
        "diffusers-training",
        "textual_inversion",
    ]

    model_card = populate_model_card(model_card, tags=tags)

    model_card.save(repo_dir / "README.md")


def log_validation(
    text_encoder_1,
    text_encoder_2,
    tokenizer_1,
    tokenizer_2,
    unet,
    vae,
    args,
    accelerator,
    weight_dtype,
    epoch,
    is_final_validation=False,
):
    """Generate images for ``args.validation_prompt`` and log them to the trackers.

    A ``StableDiffusionXLPipeline`` is assembled from the given (unwrapped) text
    encoders and UNet plus the tokenizers and VAE, its scheduler is swapped for a
    ``DPMSolverMultistepScheduler``, and ``args.num_validation_images`` images are
    sampled with 25 inference steps each.

    Args:
        text_encoder_1, text_encoder_2: Text encoders; unwrapped via ``accelerator``.
        tokenizer_1, tokenizer_2: Tokenizers handed to the pipeline unchanged.
        unet: UNet; unwrapped via ``accelerator``.
        vae: VAE handed to the pipeline unchanged.
        args: Parsed arguments; reads ``pretrained_model_name_or_path``, ``revision``,
            ``variant``, ``seed``, ``validation_prompt`` and ``num_validation_images``.
        accelerator: Supplies ``unwrap_model``, ``device`` and ``trackers``.
        weight_dtype: ``torch_dtype`` used when loading the pipeline.
        epoch: Step index passed to the tensorboard writer.
        is_final_validation: Log under the key ``"test"`` instead of ``"validation"``.

    Returns:
        The list of generated images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    unwrap = accelerator.unwrap_model
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=unwrap(text_encoder_1),
        text_encoder_2=unwrap(text_encoder_2),
        tokenizer=tokenizer_1,
        tokenizer_2=tokenizer_2,
        unet=unwrap(unet),
        vae=vae,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Seeded generator (shared across all samples) when a seed was supplied.
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    prompt = args.validation_prompt
    print(prompt)
    images = [
        pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]

    tracker_key = "validation"
    if is_final_validation:
        tracker_key = "test"

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            # Stack into a single NHWC batch for ``add_images``.
            stacked = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tracker_key, stacked, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            wandb_images = [
                wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
            ]
            tracker.log({tracker_key: wandb_images})

    # Drop the pipeline reference and release cached CUDA memory before returning.
    del pipeline
    torch.cuda.empty_cache()
    return images


def save_unet_lora(unet, accelerator, save_dir):
    """Save the UNet's LoRA layers to ``save_dir``.

    The UNet is unwrapped through ``accelerator``, its PEFT state dict is converted
    with ``convert_state_dict_to_diffusers``, and the result is written with
    ``StableDiffusionXLPipeline.save_lora_weights``. No text-encoder LoRA layers
    are saved (both are passed as ``None``).
    """
    unwrapped_unet = accelerator.unwrap_model(unet)
    peft_state_dict = get_peft_model_state_dict(unwrapped_unet)
    lora_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
    StableDiffusionXLPipeline.save_lora_weights(
        save_dir,
        unet_lora_layers=lora_state_dict,
        text_encoder_lora_layers=None,
        text_encoder_2_lora_layers=None,
    )


def parse_args(input_args=None):
    """Build the command-line parser, parse the arguments and validate them.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv``, as before.

    Returns:
        The populated ``argparse.Namespace``. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that variable is set.

    Raises:
        ValueError: If no train data directory is given, if neither
            ``--num_vectors`` nor ``--initializer_token`` is given, or if
            ``--validation_prompt`` is missing.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--save_as_full_pipeline",
        action="store_true",
        help="Save the complete stable diffusion pipeline.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--instance",
        type=str,
        default=None,
        help="The name of the instance to use. If not specified, all instances in the train data directory will be used.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        default=None,
        help="Number of vectors to learn. The model will learn a vector for each placeholder token.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--emb_learning_rate",
        type=float,
        default=5e-4,
        help="Learning rate of the optimizer that updates the token embeddings.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help=(
            "Initial learning rate (after the potential warmup period) of the optimizer that updates the UNet LoRA"
            " parameters."
        ),
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Optionally set the run name to use for logging.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=5000,  # TODO: set to 500
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="The rank of the LoRA adapter added to the UNet (also used as its lora_alpha).",
    )
    parser.add_argument(
        "--lora_modules",
        type=str,
        nargs="+",
        default=["to_q", "to_k", "to_v"],
        help="The modules to apply lora to.",
    )
    parser.add_argument(
        "--dco_beta",
        type=float,
        default=0.0,
        help="DCO beta parameter. 0.0 means no DCO.",
    )

    parser.add_argument(
        "--use_adam",
        action="store_true",
        help="Optimize the token embeddings with AdamW instead of the default SGD optimizer.",
    )
    parser.add_argument(
        "--reparameterize",
        type=str2bool,
        default=True,
        help=(
            "Whether to use reparametrization trick for the new token embeddings. This is used to improve the"
            " stability of the training process."
        ),
    )
    parser.add_argument(
        "--init_method",
        type=str,
        default="token",
        choices=["token", "random", "mean"],
        help=(
            "The method to use for initializing the new token embeddings. Choose between 'token', 'random' and 'mean'."
            " 'token' uses the embedding of the initializer token, 'random' uses a random embedding and 'mean' uses"
            " the mean of all embeddings."
        ),
    )
    parser.add_argument(
        "--init_scale",
        type=str,
        default="max",
        help="Scale setting used when initializing the new token embeddings.",
    )
    parser.add_argument(
        "--kappa",
        type=float,
        default=0.1,
        help=(
            "The concentration parameter for the von Mises-Fisher distribution. This is used to initialize the"
            " embeddings of the new token."
        ),
    )
    parser.add_argument(
        "--prior_min",
        type=float,
        default=0.0,
        help="Concentration parameter for the von Mises-Fisher distribution.",
    )
    parser.add_argument(
        "--prior_weight",
        type=float,
        default=0.0005,
        help="Weight of the prior loss.",
    )
    parser.add_argument(
        "--train_magnitude",
        action="store_true",
        help=(
            "Whether to train the magnitude of the new token embeddings. If set, the magnitude of the new token"
            " embeddings will be trained."
        ),
    )
    parser.add_argument(
        "--mag_lr_multiplier",
        type=float,
        default=1.0,
        help=(
            "The learning rate multiplier for the magnitude of the new token embeddings. This is used to control the"
            " learning rate of the magnitude of the new token embeddings."
        ),
    )
    parser.add_argument(
        "--zero_pad",
        action="store_true",
        help="Whether to pad the text with <pad> tokens.",
    )
    parser.add_argument(
        "--legacy",
        action="store_true",
        help="Whether to use the legacy code for RSGD.",
    )

    # ``parse_args(None)`` falls back to ``sys.argv[1:]``, so the default call is unchanged.
    args = parser.parse_args(input_args)

    # A launcher-provided LOCAL_RANK takes precedence over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    if args.num_vectors is None and args.initializer_token is None:
        raise ValueError("You must specify either --num_vectors or --initializer_token.")

    # The training entry point formats the validation prompt unconditionally, so a
    # missing prompt would otherwise only surface as an AttributeError after all
    # models have been loaded. Fail fast with a clear message instead.
    if args.validation_prompt is None:
        raise ValueError("You must specify --validation_prompt.")

    return args


def main():
    args = parse_args()
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir) / args.logging_dir
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load tokenizer
    tokenizer_1 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder_1 = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
    )
    unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )

    # Replace text embeddings.
    text_encoder_1 = replace_token_embedding(
        text_encoder_1, reparameterize=args.reparameterize, legacy=args.legacy
    )
    text_encoder_2 = replace_token_embedding(
        text_encoder_2, reparameterize=args.reparameterize, legacy=args.legacy
    )

    placeholder = args.placeholder_token
    added_token_ids = []
    added_token_ids_2 = []
    new_tokens = []
    new_tokens_2 = []

    new_token, init_embeds = add_new_token(
        tokenizer_1,
        text_encoder_1,
        args.placeholder_token,
        num_vectors=args.num_vectors,
        init_token=args.initializer_token,
        scale=args.init_scale,
        init_method=args.init_method,
    )
    new_token_2, init_embeds_2 = add_new_token(
        tokenizer_2,
        text_encoder_2,
        args.placeholder_token,
        num_vectors=args.num_vectors,
        # init_token=mu_2,
        init_token=args.initializer_token,
        scale=args.init_scale,
        init_method=args.init_method,
    )

    added_token_ids += new_token.token_ids.copy()
    added_token_ids_2 += new_token_2.token_ids.copy()

    new_tokens.append(new_token)
    new_tokens_2.append(new_token_2)

    init_embeds = init_embeds.to(accelerator.device)
    init_embeds_2 = init_embeds_2.to(accelerator.device)

    # Add padding token.
    placeholder = "<pad>"
    _, _ = add_new_token(
        tokenizer_1,
        text_encoder_1,
        placeholder,
        num_vectors=1,
        scale=0.0,
    )
    _, _ = add_new_token(
        tokenizer_2,
        text_encoder_2,
        placeholder,
        num_vectors=1,
        scale=0.0,
    )
    print(tokenizer_1)
    print(tokenizer_2)
    print(new_tokens)
    print(new_tokens_2)

    # Freeze vae and unet
    vae.eval().requires_grad_(False)
    unet.eval().requires_grad_(False)

    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder_1.text_model.encoder.requires_grad_(False)
    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder_2.text_model.encoder.requires_grad_(False)
    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder_2.text_projection.requires_grad_(False)

    scale_grad = args.train_magnitude
    text_encoder_1.get_input_embeddings().scales.requires_grad_(scale_grad)
    text_encoder_2.get_input_embeddings().scales.requires_grad_(scale_grad)

    if args.gradient_checkpointing:
        text_encoder_1.gradient_checkpointing_enable()
        text_encoder_2.gradient_checkpointing_enable()

    unet_lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank,
        init_lora_weights="gaussian",
        target_modules=args.lora_modules,
    )
    unet.add_adapter(unet_lora_config)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_adam:
        embedding_parameters = [
            text_encoder_1.text_model.embeddings.token_embedding.weight,
            # text_encoder_1.text_model.embeddings.token_scale.weight,
            text_encoder_2.text_model.embeddings.token_embedding.weight,
            # text_encoder_2.text_model.embeddings.token_scale.weight,
        ]
        emb_optimizer = torch.optim.AdamW(
            embedding_parameters,
            lr=args.emb_learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            # weight_decay=args.adam_weight_decay,
            weight_decay=0.0,
            eps=args.adam_epsilon,
        )
    else:
        emb_optimizer = torch.optim.SGD(
            # only optimize the embeddings
            [
                text_encoder_1.text_model.embeddings.token_embedding.weight,
                text_encoder_2.text_model.embeddings.token_embedding.weight,
            ],
            lr=args.emb_learning_rate,
            weight_decay=0.0,
        )

    unet_lora_parameters = list(
        filter(lambda p: p.requires_grad, unet.parameters())
    )
    lora_optimizer = torch.optim.AdamW(
        unet_lora_parameters,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    args.validation_prompt = args.validation_prompt.format(
        new_token.identifier,
    )
    # Dataset and DataLoaders creation:
    train_dataset = DualTokenizerDataset(
        data_root=args.train_data_dir,
        tokenizer_1=tokenizer_1,
        tokenizer_2=tokenizer_2,
        instance=args.instance,
        size=args.resolution,
        placeholder_token=new_token.identifier,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        flip_p=0.0,  # NOTE: 0.5?
        zero_pad=False,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=lora_optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
    )

    text_encoder_1.train()
    text_encoder_2.train()
    # Prepare everything with our `accelerator`.
    (
        unet, text_encoder_1, text_encoder_2,
        emb_optimizer, lora_optimizer, lr_scheduler,
        train_dataloader,
    ) = accelerator.prepare(
        unet, text_encoder_1, text_encoder_2,
        emb_optimizer, lora_optimizer, lr_scheduler,
        train_dataloader,
    )

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet and text_encoder_2 to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=torch.float32)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)

    num_lora = 0
    for name, param in unet.named_parameters():
        if param.requires_grad:
            param.to(dtype=torch.float32)
            if "lora" in name:
                num_lora += param.numel()
            else:
                print(name)
    print(f"Number of LoRA parameters: {num_lora}")

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        init_kwargs = {}
        if args.run_name is not None:
            # add random two digits to the run name to avoid name clashes
            run_name = f"{args.run_name}-{random.randint(0, 99):02d}"
            init_kwargs["wandb"] = {"name": run_name}
        accelerator.init_trackers("dti", config=vars(args), init_kwargs=init_kwargs)
        # accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = Path(args.resume_from_checkpoint).name
        else:
            # Get the most recent checkpoint
            dirs = list(Path(args.output_dir).iterdir())
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(Path(args.output_dir) / path)
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    with torch.no_grad():
        target_embeds = init_embeds / torch.linalg.norm(init_embeds, dim=-1, keepdim=True)
        target_embeds_2 = init_embeds_2 / torch.linalg.norm(init_embeds_2, dim=-1, keepdim=True)

    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()

    index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
    index_no_updates[min(added_token_ids) : max(added_token_ids) + 1] = False
    index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
    index_no_updates_2[min(added_token_ids_2) : max(added_token_ids_2) + 1] = False


    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder_1.train()
        text_encoder_2.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate([text_encoder_1, text_encoder_2]):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=vae.dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(1, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=weight_dtype)

                # Get the text embedding for conditioning
                encoder_hidden_states_1 = (
                    text_encoder_1(batch["input_ids_1"], output_hidden_states=True)
                    .hidden_states[-2]
                    .to(dtype=weight_dtype)
                )
                encoder_output_2 = text_encoder_2(
                    batch["input_ids_2"],
                    output_attentions=True,
                    output_hidden_states=True,
                )
                encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
                original_size = [
                    (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
                    for i in range(args.train_batch_size)
                ]
                crop_top_left = [
                    (batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item())
                    for i in range(args.train_batch_size)
                ]
                target_size = (args.resolution, args.resolution)
                add_time_ids = torch.cat(
                    [
                        torch.tensor(original_size[i] + crop_top_left[i] + target_size)
                        for i in range(args.train_batch_size)
                    ]
                ).to(accelerator.device, dtype=weight_dtype)
                added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
                encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)

                # Predict the noise residual
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                if args.dco_beta > 0.0:
                    with torch.no_grad():
                        cross_attention_kwargs = {"scale": 0.0}
                        refer_pred = unet(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states,
                            added_cond_kwargs=added_cond_kwargs,
                            cross_attention_kwargs=cross_attention_kwargs,
                        ).sample
                        loss_refer = F.mse_loss(refer_pred.float(), target.float(), reduction="mean")
                    diff = loss - loss_refer
                    print(diff)
                    inside_term = -1 * args.dco_beta * diff
                    loss = -1 * F.logsigmoid(inside_term)

                # NOTE: Direction regularization.
                if args.kappa > 0.0 and False:
                    kappa_t = args.kappa

                    embeds = (
                        accelerator.unwrap_model(text_encoder_1)
                        .get_input_embeddings()
                        .weight[min(added_token_ids) : max(added_token_ids) + 1]
                    )
                    if not args.legacy:
                        reg_1 = (target_embeds * embeds).sum(dim=-1)
                    else:
                        normalized_embeds = embeds / torch.linalg.norm(embeds, dim=-1, keepdim=True)
                        reg_1 = (target_embeds * normalized_embeds).sum(dim=-1)
                    reg_1 = -1 * kappa_t * reg_1

                    embeds_2 = (
                        accelerator.unwrap_model(text_encoder_2)
                        .get_input_embeddings()
                        .weight[min(added_token_ids_2) : max(added_token_ids_2) + 1]
                    )
                    if not args.legacy:
                        reg_2 = (target_embeds_2 * embeds_2).sum(dim=-1)
                    else:
                        normalized_embeds_2 = embeds_2 / torch.linalg.norm(embeds_2, dim=-1, keepdim=True)
                        reg_2 = (target_embeds_2 * normalized_embeds_2).sum(dim=-1)
                    reg_2 = -1 * kappa_t * reg_2

                    # prior_coeff = args.prior_min + relative_step * (args.prior_weight - args.prior_min)
                    prior_coeff = 1 / (1000 * args.gradient_accumulation_steps)
                    prior_term = reg_1 + reg_2
                    loss = loss + prior_coeff * prior_term.sum()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    if not args.use_adam:
                        with torch.no_grad():
                            # b2 = 1 - (global_step+1)**(-0.1)
                            if not args.legacy:
                                project_grads_to_tangent_space(
                                    accelerator.unwrap_model(text_encoder_1),
                                    added_token_ids,
                                    kappa=(args.kappa / 1000),
                                    target_embeds=target_embeds,
                                )
                                project_grads_to_tangent_space(
                                    accelerator.unwrap_model(text_encoder_2),
                                    added_token_ids_2,
                                    kappa=(args.kappa / 1000),
                                    target_embeds=target_embeds_2,
                                )
                            else:
                                grad = (
                                    accelerator.unwrap_model(text_encoder_1)
                                    .get_input_embeddings()
                                    .weight.grad
                                )[min(added_token_ids) : max(added_token_ids) + 1]
                                grad_norm = torch.linalg.norm(grad, dim=-1, keepdim=True)  # = L2 norm
                                grad_norm = torch.clamp(grad_norm, min=1e-6)
                                accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.grad[
                                    min(added_token_ids) : max(added_token_ids) + 1
                                ] = grad / grad_norm
                                # h_1 = grad.square().sum(dim=1)
                                # v_t1 = v_t1 * b2 + (1-b2) * h_1
                                # accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.grad[
                                #     min(added_token_ids) : max(added_token_ids) + 1
                                # ] = grad / torch.sqrt(v_t1 + 1e-8)

                                grad = (
                                    accelerator.unwrap_model(text_encoder_2)
                                    .get_input_embeddings()
                                    .weight.grad
                                )[min(added_token_ids_2) : max(added_token_ids_2) + 1]
                                grad_norm = torch.linalg.norm(grad, dim=-1, keepdim=True)
                                grad_norm = torch.clamp(grad_norm, min=1e-6)
                                accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.grad[
                                    min(added_token_ids_2) : max(added_token_ids_2) + 1
                                ] = grad / grad_norm
                                # h_2 = grad.square().sum(dim=1)
                                # v_t2 = v_t2 * b2 + (1-b2) * h_2
                                # accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.grad[
                                #     min(added_token_ids_2) : max(added_token_ids_2) + 1
                                # ] = grad / torch.sqrt(v_t2 + 1e-8)

                emb_optimizer.step()
                lora_optimizer.step()
                lr_scheduler.step()
                emb_optimizer.zero_grad()
                lora_optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token
                if accelerator.sync_gradients:
                    with torch.no_grad():
                        embeddings = (
                            accelerator.unwrap_model(text_encoder_1)
                            .get_input_embeddings()
                        )
                        embeddings.weight[
                            index_no_updates
                        ] = orig_embeds_params[index_no_updates]
                        index_updates = ~index_no_updates
                        retract_token_embeddings(
                            embeddings,
                            index_updates=index_updates,
                        )

                        embeddings_2 = (
                            accelerator.unwrap_model(text_encoder_2)
                            .get_input_embeddings()
                        )
                        embeddings_2.weight[
                            index_no_updates_2
                        ] = orig_embeds_params_2[index_no_updates_2]
                        index_updates_2 = ~index_no_updates_2
                        retract_token_embeddings(
                            embeddings_2,
                            index_updates=index_updates_2,
                        )


            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                images = []
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    weight_name = f"learned_embeds-steps-{global_step}.safetensors"
                    save_path = Path(args.output_dir) / weight_name
                    save_progress(
                        accelerator.unwrap_model(text_encoder_1),
                        new_tokens,
                        save_path,
                        safe_serialization=True,
                    )
                    weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
                    save_path = Path(args.output_dir) / weight_name
                    save_progress(
                        accelerator.unwrap_model(text_encoder_2),
                        new_tokens_2,
                        save_path,
                        safe_serialization=True,
                    )
                    ckpt_dir = Path(args.output_dir) / f"checkpoint-{global_step}"
                    os.makedirs(ckpt_dir, exist_ok=True)
                    save_unet_lora(
                        unet=unet,
                        accelerator=accelerator,
                        save_dir=ckpt_dir,
                    )

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = list(Path(args.output_dir).iterdir())
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = Path(args.output_dir) / removing_checkpoint
                                    shutil.rmtree(removing_checkpoint)

                        save_path = Path(args.output_dir) / f"checkpoint-{global_step}"
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        images = log_validation(
                            text_encoder_1,
                            text_encoder_2,
                            tokenizer_1,
                            tokenizer_2,
                            unet,
                            vae,
                            args,
                            accelerator,
                            weight_dtype,
                            epoch,
                        )
                        rows = args.num_validation_images // 2
                        cols = 2
                        image_grid = diffusers.utils.make_image_grid(images, rows=rows, cols=cols)
                        image_grid.save(Path(args.output_dir) / f"validation-{global_step:04d}.jpg")

            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.validation_prompt:
            images = log_validation(
                text_encoder_1,
                text_encoder_2,
                tokenizer_1,
                tokenizer_2,
                unet,
                vae,
                args,
                accelerator,
                weight_dtype,
                epoch,
                is_final_validation=True,
            )

        if args.push_to_hub and not args.save_as_full_pipeline:
            logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = args.save_as_full_pipeline
        if save_full_model:
            pipeline = StableDiffusionXLPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder_1),
                text_encoder_2=accelerator.unwrap_model(text_encoder_2),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        weight_name = "learned_embeds.safetensors"
        save_path = Path(args.output_dir) / weight_name
        save_progress(
            accelerator.unwrap_model(text_encoder_1),
            new_tokens,
            save_path,
            safe_serialization=True,
        )
        weight_name = "learned_embeds_2.safetensors"
        save_path = Path(args.output_dir) / weight_name
        save_progress(
            accelerator.unwrap_model(text_encoder_2),
            new_tokens_2,
            save_path,
            safe_serialization=True,
        )
        save_unet_lora(
            unet=unet,
            accelerator=accelerator,
            save_dir=args.output_dir,
        )
        if args.push_to_hub:
            save_model_card(
                repo_id,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not trigger a run as a side effect.
if __name__ == "__main__":
    main()
