# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

import argparse
import hashlib
import itertools
import math
import os
import inspect
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import logging

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    EulerAncestralDiscreteScheduler,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from lora_diffusion import (
    extract_lora_ups_down,
    inject_trainable_lora,
    safetensors_available,
    save_lora_weight,
    save_safeloras,
    tune_lora_scale,
)
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from pathlib import Path

import random
import re
from utils import display_images, save_histogram
from itertools import cycle


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Each item is a dict with the keys ``images`` (transformed image tensor),
    ``prompt_ids``, ``original_prompt_ids``, ``wrong_prompt_ids`` (unpadded token
    id lists) and ``is_class`` (``True`` for items that come from the class images).

    Args:
        instance_data_dir: Path passed to ``torch.load``. The loaded object is
            iterated and every entry is indexed with ``"impath"`` and ``"classname"``.
        tokenizer: Callable tokenizer exposing ``model_max_length``; its result
            must provide ``input_ids``.
        prompt_template: Format string with one positional placeholder that
            receives the (optionally prefixed) class name.
        instance_prompt: Optional text placed in front of the class name when
            building the training prompts.
        class_data_root: Optional directory with class images. It is created
            when missing.
        class_prompt: Prompt attached to every class image.
        num_class_images: Number of class images to use. ``None`` means "all
            files found in ``class_data_root``". Values larger than the number
            of files found are clamped to that number.
        size: Target size for the resize / center-crop transforms.
        center_crop: Append a center crop to ``size``.
        color_jitter: Append ``ColorJitter(0.2, 0.1)``.
        h_flip: Append a random horizontal flip.
        resize: Append a bilinear resize to ``size``.
    """

    def __init__(
        self,
        instance_data_dir,
        tokenizer,
        prompt_template,
        instance_prompt=None,
        class_data_root=None,
        class_prompt=None,
        num_class_images=None,
        size=512,
        center_crop=False,
        color_jitter=False,
        h_flip=False,
        resize=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.resize = resize

        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code. Only point this at data files from a trusted source.
        self.train_data = torch.load(instance_data_dir)
        self.instance_images_path = [data["impath"] for data in self.train_data]

        # Class names are stored with underscores; prompts use spaces instead.
        class_names = [data["classname"].replace("_", " ") for data in self.train_data]
        self.original_prompts = [prompt_template.format(name) for name in class_names]
        if instance_prompt is None:
            # Separate list object: original_prompts may be extended in place below.
            self.instance_prompts = list(self.original_prompts)
        else:
            self.instance_prompts = [
                prompt_template.format(instance_prompt + " " + name)
                for name in class_names
            ]
        self.num_instance_images = len(self.instance_images_path)
        self._length = self.num_instance_images

        self.unique_instance_prompts = set(self.instance_prompts)
        # Set iteration order depends on string hashing, so keep a sorted copy to
        # make the "wrong prompt" sampling reproducible under a fixed random seed.
        self._sorted_unique_prompts = sorted(self.unique_instance_prompts)
        if self._sorted_unique_prompts:
            print(self._sorted_unique_prompts[0])

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            # Sorted so that the selected subset of class images is deterministic.
            self.class_images_path = sorted(self.class_data_root.iterdir())
            self.num_class_images = self._effective_class_count(
                num_class_images, len(self.class_images_path)
            )
            self._length = self.num_class_images + self.num_instance_images
            self.class_prompt = class_prompt

            self.images_path = (
                self.instance_images_path
                + self.class_images_path[: self.num_class_images]
            )
            self.prompts = (
                self.instance_prompts + [self.class_prompt] * self.num_class_images
            )
            self.original_prompts += [self.class_prompt] * self.num_class_images
        else:
            self.images_path = self.instance_images_path
            self.prompts = self.instance_prompts

        img_transforms = []

        if resize:
            img_transforms.append(
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
            )
        if center_crop:
            img_transforms.append(transforms.CenterCrop(size))
        if color_jitter:
            img_transforms.append(transforms.ColorJitter(0.2, 0.1))
        if h_flip:
            img_transforms.append(transforms.RandomHorizontalFlip())

        self.image_transforms = transforms.Compose(
            [*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )

    @staticmethod
    def _effective_class_count(requested, available):
        """Return how many class images can be used given what is on disk."""
        if requested is None:
            return available
        if requested > available:
            import warnings

            # Without clamping, the dataset length would exceed the number of
            # image paths and indexing the tail would raise IndexError.
            warnings.warn(
                f"Requested {requested} class images but only {available} were "
                f"found; using {available}."
            )
        return max(0, min(requested, available))

    def _tokenize(self, prompt):
        """Tokenize one prompt without padding, truncated to the model maximum."""
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def _sample_wrong_prompt(self, index):
        """Pick a random instance prompt that differs from the prompt at ``index``.

        Falls back to the item's own prompt when no other prompt exists (for
        example a single-class dataset) instead of raising ``IndexError``.
        """
        own_prompt = self.prompts[index]
        candidates = [p for p in self._sorted_unique_prompts if p != own_prompt]
        if not candidates:
            return own_prompt
        return random.choice(candidates)

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixels have been read.
        with Image.open(self.images_path[index]) as raw_image:
            image = raw_image.convert("RGB")

        example = {
            "images": self.image_transforms(image),
            "prompt_ids": self._tokenize(self.prompts[index]),
            "original_prompt_ids": self._tokenize(self.original_prompts[index]),
            "wrong_prompt_ids": self._tokenize(self._sample_wrong_prompt(index)),
            # Class images are stored after all instance images.
            "is_class": index >= self.num_instance_images,
        }
        return example


class PromptDataset(Dataset):
    """Minimal dataset that serves one fixed prompt, used to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        # Every item carries the same prompt; only the index differs.
        self.num_samples = num_samples
        self.prompt = prompt

    def __len__(self):
        # The dataset is purely virtual: its size is whatever was requested.
        return self.num_samples

    def __getitem__(self, index):
        # The index is returned alongside the prompt so callers can tell items apart.
        return {"prompt": self.prompt, "index": index}


# Module-level logger obtained from accelerate's get_logger, with its level set to INFO.
logger = get_logger(__name__, log_level="INFO")


def _str_to_bool(value):
    """Convert a command-line string such as "true" or "0" into a real boolean."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching the behaviour of plain bool("").
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args(input_args=None):
    """Build the command-line parser, parse the arguments and validate them.

    Args:
        input_args: Optional list of argument strings. When ``None`` the
            arguments are taken from ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that is set, and
        ``output_format`` is downgraded from ``"both"`` to ``"pt"`` when
        safetensors is not available.

    Raises:
        ValueError: If ``--with_prior_preservation`` is given without
            ``--class_data_dir`` or ``--class_prompt``, or if
            ``--output_format safe`` is requested while safetensors is not
            available.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=False,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--prompt_template",
        type=str,
        default=None,
        required=True,
        help="Prompt template to format instance prompt",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        choices=["pt", "safe", "both"],
        default="both",
        help="The output format of the model predicitions and checkpoints.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    parser.add_argument(
        "--color_jitter",
        action="store_true",
        help="Whether to apply color jitter to images",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for sampling images.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X updates steps.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="Rank of LoRA approximation.",
    )
    # NOTE(review): the default is None, so callers are expected to pass
    # --learning_rate explicitly - confirm that downstream code handles None.
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=None,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_text",
        type=float,
        default=5e-6,
        help="Initial learning rate for text encoder (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--save_at_last",
        action="store_true",
        help="Whether or not to save safetensor and weights at last.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--resume_unet",
        type=str,
        default=None,
        help=("File path for unet lora to resume training."),
    )
    parser.add_argument(
        "--resume_text_encoder",
        type=str,
        default=None,
        help=("File path for text encoder lora to resume training."),
    )
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so the value is parsed explicitly instead.
    parser.add_argument(
        "--resize",
        type=_str_to_bool,
        default=True,
        required=False,
        help="Should images be resized to --resolution before training?",
    )
    parser.add_argument(
        "--use_xformers", action="store_true", help="Whether or not to use xformers"
    )
    # ADDED ARGUMENTS
    parser.add_argument(
        "--cutoff_T", type=int, default=-1, help=("training from 0...cutoff_T")
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # A LOCAL_RANK set in the environment takes precedence over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning(
                "You need not use --class_data_dir without --with_prior_preservation."
            )
        if args.class_prompt is not None:
            logger.warning(
                "You need not use --class_prompt without --with_prior_preservation."
            )

    if not safetensors_available:
        if args.output_format == "both":
            print(
                "Safetensors is not available - changing output format to just output PyTorch files"
            )
            args.output_format = "pt"
        elif args.output_format == "safe":
            raise ValueError(
                "Safetensors is not available - either install it, or change output_format."
            )

    return args


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        project_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if (
        args.train_text_encoder
        and args.gradient_accumulation_steps > 1
        and accelerator.num_processes > 1
    ):
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    # Add file handler to logging
    file_handler = logging.FileHandler(f"{args.output_dir}/logging.txt")
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.logger.addHandler(file_handler)

    logger.info("Running experiments with arguments...")
    arg_dict = vars(args)
    for key in arg_dict.keys():
        logger.info(f"{key}: {arg_dict[key]}")

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))
        if args.num_class_images is None and cur_class_images == 0:
            assert False, "Must specify number of class images if current class folder is empty."

        if args.num_class_images is not None and cur_class_images < args.num_class_images:
            torch_dtype = (
                torch.float16 if accelerator.device.type == "cuda" else torch.float32
            )
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=args.sample_batch_size
            )

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=not accelerator.is_local_main_process,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = (
                        class_images_dir
                        / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    )
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
        subfolder=None if args.pretrained_vae_name_or_path else "vae",
        revision=None if args.pretrained_vae_name_or_path else args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )
    unet.requires_grad_(False)
    unet_lora_params, _ = inject_trainable_lora(
        unet, r=args.lora_rank, loras=args.resume_unet
    )

    '''
    for _up, _down in extract_lora_ups_down(unet):
        print("Before training: Unet First Layer lora up", _up.weight.data)
        print("Before training: Unet First Layer lora down", _down.weight.data)
        break
    '''

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=["CLIPAttention"],
            r=args.lora_rank,
        )
        '''
        for _up, _down in extract_lora_ups_down(
            text_encoder, target_replace_module=["CLIPAttention"]
        ):
            print("Before training: text encoder First Layer lora up", _up.weight.data)
            print(
                "Before training: text encoder First Layer lora down", _down.weight.data
            )
            break
        '''

    if args.use_xformers:
        set_use_memory_efficient_attention_xformers(unet, True)
        set_use_memory_efficient_attention_xformers(vae, True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    text_lr = (
        args.learning_rate
        if args.learning_rate_text is None
        else args.learning_rate_text
    )

    params_to_optimize = (
        [
            {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_lr,
            },
        ]
        if args.train_text_encoder
        else itertools.chain(*unet_lora_params)
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )

    train_dataset = DreamBoothDataset(
        instance_data_dir=args.instance_data_dir,
        prompt_template=args.prompt_template,
        num_class_images=args.num_class_images,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        color_jitter=args.color_jitter,
        resize=args.resize,
    )
    
    val_dataset = DreamBoothDataset(
        instance_data_dir=os.path.dirname(args.instance_data_dir) + "/test.data",
        prompt_template=args.prompt_template,
        num_class_images=args.num_class_images,
        instance_prompt=args.instance_prompt,
        class_data_root=None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        color_jitter=args.color_jitter,
        resize=args.resize,
    )

    def collate_fn(examples):
        input_ids = [example["prompt_ids"] for example in examples]
        original_input_ids = [example["original_prompt_ids"] for example in examples]
        pixel_values = [example["images"] for example in examples]
        wrong_input_ids = [example["wrong_prompt_ids"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        class_mask = torch.tensor([example["is_class"] for example in examples], dtype=torch.bool)
        
        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        original_input_ids = tokenizer.pad(
            {"input_ids": original_input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        wrong_input_ids = tokenizer.pad(
            {"input_ids": wrong_input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "original_input_ids": original_input_ids,
            "wrong_input_ids": wrong_input_ids,
            "pixel_values": pixel_values,
            "class_mask": class_mask
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=1,
    )

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=1,
    )
    

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        (
            unet,
            text_encoder,
            optimizer,
            train_dataloader,
            val_dataloader,
            lr_scheduler,
        ) = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )
    val_iterator = cycle(iter(val_dataloader))

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num batches each epoch = {len(train_dataloader)}")
    print(f"  Num Epochs = {args.num_train_epochs}")
    print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    print(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    print(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0
    last_save = 0
    
    # for logging
    T = noise_scheduler.config.num_train_timesteps
    loss_at_ts = torch.zeros(T // 5, device=accelerator.device)
    discriminative_at_ts = torch.zeros(T // 5, device=accelerator.device)
    improvement_at_ts = torch.zeros(T // 5, device=accelerator.device)
    val_dis_at_ts = torch.zeros(T // 5, device=accelerator.device)
    val_imp_at_ts = torch.zeros(T // 5, device=accelerator.device)
    #sd = StableDiffusionPipeline.from_pretrained(
    #    args.pretrained_model_name_or_path,
    #    torch_dtype=torch.float16 if accelerator.device.type == "cuda" else torch.float32,
    #    safety_checker=None,
    #    revision=args.revision,
    #).to(accelerator.device)
    
    # train!
    logger.info("Start training...")
    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()

        def process_batch(pixel_values, input_ids, noise=None, timesteps=None, max_T=-1):
            """Run one noised forward pass and return the prediction with its target.

            Relies on ``vae``, ``weight_dtype``, ``T``, ``noise_scheduler``,
            ``text_encoder`` and ``unet`` from the enclosing scope.

            Args:
                pixel_values: Input handed to ``vae.encode`` after being cast to
                    ``weight_dtype``.
                input_ids: Input handed to ``text_encoder`` to produce the
                    conditioning for ``unet``.
                noise: Optional noise to mix into the latents. When ``None`` it is
                    drawn with ``torch.randn_like``.
                timesteps: Optional per-sample timesteps. When ``None`` they are
                    drawn with ``torch.randint`` from ``[0, cutoff)``.
                max_T: Exclusive upper bound used when sampling timesteps. A
                    negative value selects ``T`` instead.

            Returns:
                The tuple ``(model_pred, target, timesteps, noise, noisy_latents)``.

            Raises:
                ValueError: If ``noise_scheduler.config.prediction_type`` is neither
                    ``"epsilon"`` nor ``"v_prediction"``.
            """
            # Encode into latent space and rescale.
            # NOTE(review): 0.18215 is presumably the VAE latent scaling factor — confirm.
            posterior = vae.encode(pixel_values.to(dtype=weight_dtype)).latent_dist
            latents = posterior.sample() * 0.18215

            # Only draw fresh noise when the caller did not provide any.
            if noise is None:
                noise = torch.randn_like(latents)

            # A negative max_T means "no cutoff": sample from the full range T.
            if max_T < 0:
                upper_bound = T
            else:
                upper_bound = max_T

            # One random timestep per sample unless the caller fixed them.
            if timesteps is None:
                timesteps = torch.randint(
                    0, upper_bound, (latents.shape[0],), device=latents.device
                )
            timesteps = timesteps.long()

            # Forward diffusion: noise the latents according to each timestep.
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Condition the UNet on the text encoder's first output.
            text_embeddings = text_encoder(input_ids)[0]
            model_pred = unet(noisy_latents, timesteps, text_embeddings).sample

            # Choose the regression target that matches the scheduler's prediction type.
            prediction_type = noise_scheduler.config.prediction_type
            if prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            elif prediction_type == "epsilon":
                target = noise
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

            return model_pred, target, timesteps, noise, noisy_latents
        
        for step, batch in enumerate(train_dataloader):
            model_pred, target, ts, es, xts = process_batch(
                batch["pixel_values"], batch["input_ids"], max_T=args.cutoff_T
            )
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))

            '''
            class_mask = batch["class_mask"]
            with torch.no_grad():
                encoder_hidden_states = text_encoder(batch["wrong_input_ids"])[0]
                model_pred = unet(xts, ts, encoder_hidden_states).sample
                wrong_class_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))
                loss_diff = wrong_class_loss - loss
                discriminative_at_ts[ts[~class_mask]//5] = \
                    discriminative_at_ts[ts[~class_mask]//5] * 0.95 + loss_diff[~class_mask].detach() * 0.05
                
                encoder_hidden_states = sd.text_encoder(batch["original_input_ids"])[0]
                model_pred = sd.unet(xts, ts, encoder_hidden_states).sample
                sd_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))
                sd_loss_diff = sd_loss - loss
                improvement_at_ts[ts[~class_mask]//5] = \
                    improvement_at_ts[ts[~class_mask]//5] * 0.95 + sd_loss_diff[~class_mask].detach() * 0.05
            '''
            accelerator.backward(loss.mean())
            if accelerator.sync_gradients:
                params_to_clip = (
                    itertools.chain(unet.parameters(), text_encoder.parameters())
                    if args.train_text_encoder
                    else unet.parameters()
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            progress_bar.update(1)
            optimizer.zero_grad()
            loss_at_ts[ts//5] = loss_at_ts[ts//5] * 0.95 + loss.detach() * 0.05
            
            logs = {
                "loss": loss.mean().detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                #"train_dis": loss_diff.mean().detach().item(),
                #"train_imp": sd_loss_diff.mean().detach().item(),
            }
            
            # val
            '''
            with torch.no_grad():
                val_batch = next(val_iterator)
                model_pred, target, ts, _, xts = process_batch(
                    val_batch["pixel_values"], val_batch["input_ids"], max_T=-1
                )   # random noise but same timestep as train
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))
                encoder_hidden_states = text_encoder(val_batch["wrong_input_ids"])[0]
                model_pred = unet(xts, ts, encoder_hidden_states).sample
                wrong_class_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))
                loss_diff = wrong_class_loss - loss
                val_dis_at_ts[ts//5] = \
                    val_dis_at_ts[ts//5] * 0.95 + loss_diff.detach() * 0.05
                
                encoder_hidden_states = sd.text_encoder(val_batch["original_input_ids"])[0]
                model_pred = sd.unet(xts, ts, encoder_hidden_states).sample
                sd_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean(dim=(1,2,3))
                sd_loss_diff = sd_loss - loss
                val_imp_at_ts[ts//5] = \
                    val_imp_at_ts[ts//5] * 0.95 + sd_loss_diff.detach() * 0.05
            logs["val_dis"] = loss_diff.mean().detach().item()
            logs["val_imp"] = sd_loss_diff.mean().detach().item()
            '''
            

            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.save_steps and global_step - last_save >= args.save_steps:
                    if accelerator.is_main_process:
                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
                        # it, the models will be unwrapped, and when they are then used for further training,
                        # we will crash. pass this, but only to newer versions of accelerate. fixes
                        # https://github.com/huggingface/diffusers/issues/1566
                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                            inspect.signature(
                                accelerator.unwrap_model
                            ).parameters.keys()
                        )
                        extra_args = (
                            {"keep_fp32_wrapper": True}
                            if accepts_keep_fp32_wrapper
                            else {}
                        )
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet, **extra_args),
                            text_encoder=accelerator.unwrap_model(
                                text_encoder, **extra_args
                            ),
                            revision=args.revision,
                        )

                        #filename_unet = (
                        #    f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                        #)
                        filename_unet = f"{args.output_dir}/unet_{global_step}.pt"
                        # filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                        filename_text_encoder = f"{args.output_dir}/text_{global_step}.pt"
                        print(f"save weights {filename_unet}, {filename_text_encoder}")
                        save_lora_weight(pipeline.unet, filename_unet)
                        if args.train_text_encoder:
                            save_lora_weight(
                                pipeline.text_encoder,
                                filename_text_encoder,
                                target_replace_module=["CLIPAttention"],
                            )

                        # Save sample photos
                        pipeline.to("cuda")
                        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
                        tune_lora_scale(pipeline.unet, 0.6)
                        tune_lora_scale(pipeline.text_encoder, 0.7)
                        with torch.no_grad():
                            prompts = random.choices(list(train_dataset.unique_instance_prompts), k=4)
                            images = pipeline(prompts, num_inference_steps=30, guidance_scale=7).images
                            grid_images = display_images(images, nrow=2, padding=5, size=1024, texts=prompts)
                            grid_images.save(f"{args.output_dir}/images_{global_step}.png")

                        # plot loss at each timestep
                        #save_histogram(
                        #    range(0, T, 5), loss_at_ts.cpu().numpy(), range(0, T, 50),
                        #    f"{args.output_dir}/loss_{global_step}.png"
                        #)
                        #save_histogram(
                        #    range(0, T, 5), discriminative_at_ts.cpu().numpy(), range(0, T, 50),
                        #    f"{args.output_dir}/discriminative_{global_step}.png"
                        #)
                        #save_histogram(
                        #    range(0, T, 5), improvement_at_ts.cpu().numpy(), range(0, T, 50),
                        #    f"{args.output_dir}/improvements_{global_step}.png"
                        #)
                        #save_histogram(
                        #    range(0, T, 5), val_dis_at_ts.cpu().numpy(), range(0, T, 50),
                        #    f"{args.output_dir}/val_discriminative_{global_step}.png"
                        #)
                        #save_histogram(
                        #    range(0, T, 5), val_imp_at_ts.cpu().numpy(), range(0, T, 50),
                        #    f"{args.output_dir}/val_improvements_{global_step}.png"
                        #)
                        
                        logger.info(f"Saving at step {global_step}...")
                        #logger.info(f"Train split: Correct prompt > wrong prompt by {discriminative_at_ts.cpu().numpy().mean()}.")
                        #logger.info(f"Train split: LoRA SD > SD by {improvement_at_ts.cpu().numpy().mean()}.")
                        #logger.info(f"Val split: Correct prompt > wrong prompt by {val_dis_at_ts.cpu().numpy().mean()}.")
                        #logger.info(f"Val split: LoRA SD > SD by {val_imp_at_ts.cpu().numpy().mean()}.")

                        del pipeline
                        last_save = global_step
                        
            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process and args.save_at_last:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
            revision=args.revision,
        )

        logger.info("\n\nLora TRAINING DONE!\n\n")

        if args.output_format == "pt" or args.output_format == "both":
            save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
            if args.train_text_encoder:
                save_lora_weight(
                    pipeline.text_encoder,
                    args.output_dir + "/lora_weight.text_encoder.pt",
                    target_replace_module=["CLIPAttention"],
                )

        if args.output_format == "safe" or args.output_format == "both":
            loras = {}
            loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"})
            if args.train_text_encoder:
                loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"})

            save_safeloras(loras, args.output_dir + "/lora_weight.safetensors")

        if args.push_to_hub:
            repo.push_to_hub(
                commit_message="End of training",
                blocking=False,
                auto_lfs_prune=True,
            )

    accelerator.end_training()
    file_handler.close()


# Script entry point: runs only when this file is executed directly, not when it
# is imported. Builds the run configuration with parse_args() (defined elsewhere
# in this file) and hands it to main(), which contains the training loop.
# Note: `args` is bound at module level here.
if __name__ == "__main__":
    args = parse_args()
    main(args)
